ynuozhang committed
Commit: baf3373
Parent(s): b8c6018

update code

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- metrics/nonfouling/train_predictions_binary.csv +0 -0
- metrics/nonfouling/val_predictions_binary.csv +3 -3438
- tokenizer/.ipynb_checkpoints/my_tokenizers-checkpoint.py +398 -0
- tokenizer/__pycache__/my_tokenizers.cpython-310.pyc +0 -0
- tokenizer/my_tokenizers.py +398 -0
- tokenizer/new_splits.txt +159 -0
- tokenizer/new_vocab.txt +586 -0
- training_classifiers/.gitignore +0 -0
- training_classifiers/.ipynb_checkpoints/binding_affinity_iptm-checkpoint.py +132 -0
- training_classifiers/.ipynb_checkpoints/binding_affinity_split-checkpoint.py +847 -0
- training_classifiers/.ipynb_checkpoints/binding_training-checkpoint.py +414 -0
- training_classifiers/.ipynb_checkpoints/binding_wt-checkpoint.bash +31 -0
- training_classifiers/.ipynb_checkpoints/finetune_boost-checkpoint.py +508 -0
- training_classifiers/.ipynb_checkpoints/generate_binding_val-checkpoint.py +309 -0
- training_classifiers/.ipynb_checkpoints/peptiverse_filelist-checkpoint.txt +234 -0
- training_classifiers/.ipynb_checkpoints/train_boost-checkpoint.py +417 -0
- training_classifiers/.ipynb_checkpoints/train_ml-checkpoint.py +468 -0
- training_classifiers/.ipynb_checkpoints/train_ml_regression-checkpoint.py +410 -0
- training_classifiers/.ipynb_checkpoints/train_nn-checkpoint.py +426 -0
- training_classifiers/.ipynb_checkpoints/train_nn_regression-checkpoint.py +420 -0
- training_data_cleaned/data_split.ipynb → training_classifiers/binding_affinity/val_smiles_pooled.csv +2 -2
- training_data_cleaned/nf_smiles_train.csv → training_classifiers/binding_affinity/val_smiles_unpooled.csv +2 -2
- training_data_cleaned/smiles_data_split.ipynb → training_classifiers/binding_affinity/val_wt_pooled.csv +2 -2
- training_data_cleaned/nf_smiles_val.csv → training_classifiers/binding_affinity/val_wt_unpooled.csv +2 -2
- training_classifiers/binding_affinity/wt_smiles_pooled/best_model.pt +3 -0
- training_classifiers/binding_affinity/wt_smiles_pooled/best_params.json +10 -0
- training_classifiers/binding_affinity/wt_smiles_pooled/optuna_trials.csv +3 -0
- training_classifiers/binding_affinity/wt_smiles_unpooled/.ipynb_checkpoints/best_params-checkpoint.json +10 -0
- training_classifiers/binding_affinity/wt_smiles_unpooled/best_model.pt +3 -0
- training_classifiers/binding_affinity/wt_smiles_unpooled/best_params.json +10 -0
- training_classifiers/binding_affinity/wt_smiles_unpooled/optuna_trials.csv +3 -0
- training_classifiers/binding_affinity/wt_wt_pooled/.ipynb_checkpoints/optuna_trials-checkpoint.csv +3 -0
- training_classifiers/binding_affinity/wt_wt_pooled/best_model.pt +3 -0
- training_classifiers/binding_affinity/wt_wt_pooled/best_params.json +10 -0
- training_classifiers/binding_affinity/wt_wt_pooled/optuna_trials.csv +3 -0
- training_classifiers/binding_affinity/wt_wt_unpooled/best_model.pt +3 -0
- training_classifiers/binding_affinity/wt_wt_unpooled/best_params.json +10 -0
- training_classifiers/binding_affinity/wt_wt_unpooled/optuna_trials.csv +3 -0
- training_classifiers/binding_affinity_iptm.py +132 -0
- training_classifiers/binding_affinity_split.py +847 -0
- training_classifiers/binding_training.py +414 -0
- training_classifiers/binding_wt.bash +31 -0
- training_classifiers/hemolysis/cnn_smiles/best_model.pt +3 -0
- training_classifiers/hemolysis/cnn_smiles/best_model_benchmark.json +39 -0
- training_classifiers/hemolysis/cnn_smiles/optimization_summary.txt +19 -0
- training_classifiers/hemolysis/cnn_smiles/pr_curve.png +0 -0
- training_classifiers/hemolysis/cnn_smiles/roc_curve.png +0 -0
- training_classifiers/hemolysis/cnn_smiles/study_trials.csv +3 -0
- training_classifiers/hemolysis/cnn_smiles/train_predictions.csv +3 -0
- training_classifiers/hemolysis/cnn_smiles/val_predictions.csv +3 -0
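Each binding_affinity run directory above leaves behind the same three artifacts: a best_params.json, a best_model.pt checkpoint, and an optuna_trials.csv with the search history. A minimal sketch of how such artifacts are typically consumed is below; the directory name is taken from the file list, but the checkpoint layout and the JSON structure are assumptions, not documented in this commit.

```python
import json

import pandas as pd
import torch

run_dir = "training_classifiers/binding_affinity/wt_smiles_pooled"

# Hyperparameters selected by the Optuna search
# (assumed to be a flat JSON dict of name -> value).
with open(f"{run_dir}/best_params.json") as f:
    best_params = json.load(f)

# Full trial history, useful for inspecting the search itself.
trials = pd.read_csv(f"{run_dir}/optuna_trials.csv")

# Assumption: best_model.pt may hold either a state_dict or a
# pickled module; load it to CPU and inspect before using it.
checkpoint = torch.load(f"{run_dir}/best_model.pt", map_location="cpu")

print(best_params, type(checkpoint), len(trials))
```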
metrics/nonfouling/train_predictions_binary.csv CHANGED
The diff for this file is too large to render. See raw diff.
metrics/nonfouling/val_predictions_binary.csv CHANGED
@@ -1,3438 +1,3 @@
[3,438 prediction rows removed, each of the form `true_label,probability,predicted_label` (e.g. `0,0.0048168427,0`, `1,0.6661874,1`); the rendered diff truncates at row 1986 and the 3 retained lines are not shown]
-
0,0.0011100766,0
|
| 1987 |
-
0,0.5804321,1
|
| 1988 |
-
0,0.0010168053,0
|
| 1989 |
-
0,0.22010425,0
|
| 1990 |
-
1,0.32657132,0
|
| 1991 |
-
0,0.0009580052,0
|
| 1992 |
-
1,0.5661701,1
|
| 1993 |
-
0,0.8999228,1
|
| 1994 |
-
0,0.0010242854,0
|
| 1995 |
-
0,0.00079186546,0
|
| 1996 |
-
0,0.0008344046,0
|
| 1997 |
-
0,0.006115876,0
|
| 1998 |
-
0,0.6555151,1
|
| 1999 |
-
0,0.47155666,0
|
| 2000 |
-
0,0.0011539217,0
|
| 2001 |
-
0,0.42844838,0
|
| 2002 |
-
0,0.00095852744,0
|
| 2003 |
-
0,0.0011793413,0
|
| 2004 |
-
0,0.5727451,1
|
| 2005 |
-
0,0.0011716434,0
|
| 2006 |
-
0,0.017039847,0
|
| 2007 |
-
0,0.0011528861,0
|
| 2008 |
-
0,0.0008173667,0
|
| 2009 |
-
0,0.0068512554,0
|
| 2010 |
-
1,0.752859,1
|
| 2011 |
-
1,0.67884725,1
|
| 2012 |
-
0,0.002062399,0
|
| 2013 |
-
0,0.0008773194,0
|
| 2014 |
-
0,0.4187466,0
|
| 2015 |
-
0,0.0008783964,0
|
| 2016 |
-
0,0.0011804664,0
|
| 2017 |
-
0,0.00090095046,0
|
| 2018 |
-
0,0.0008461568,0
|
| 2019 |
-
0,0.00090098643,0
|
| 2020 |
-
1,0.4443772,0
|
| 2021 |
-
0,0.0012510206,0
|
| 2022 |
-
0,0.0010878003,0
|
| 2023 |
-
0,0.001125162,0
|
| 2024 |
-
0,0.0009025954,0
|
| 2025 |
-
0,0.00082608557,0
|
| 2026 |
-
0,0.0012391479,0
|
| 2027 |
-
0,0.0010274188,0
|
| 2028 |
-
0,0.00085046276,0
|
| 2029 |
-
0,0.00094802276,0
|
| 2030 |
-
1,0.20541348,0
|
| 2031 |
-
1,0.37556043,0
|
| 2032 |
-
0,0.40013596,0
|
| 2033 |
-
0,0.6528374,1
|
| 2034 |
-
0,0.6438984,1
|
| 2035 |
-
0,0.33941498,0
|
| 2036 |
-
1,0.6012913,1
|
| 2037 |
-
1,0.40081245,0
|
| 2038 |
-
0,0.0009239743,0
|
| 2039 |
-
0,0.00088827714,0
|
| 2040 |
-
1,0.69142634,1
|
| 2041 |
-
0,0.00095390016,0
|
| 2042 |
-
0,0.00091235427,0
|
| 2043 |
-
0,0.46778736,0
|
| 2044 |
-
1,0.6588788,1
|
| 2045 |
-
0,0.00085993443,0
|
| 2046 |
-
0,0.00095059595,0
|
| 2047 |
-
0,0.31993032,0
|
| 2048 |
-
1,0.31242043,0
|
| 2049 |
-
1,0.7491844,1
|
| 2050 |
-
0,0.00084422436,0
|
| 2051 |
-
1,0.81296134,1
|
| 2052 |
-
0,0.0009211382,0
|
| 2053 |
-
1,0.45493677,0
|
| 2054 |
-
0,0.0011309453,0
|
| 2055 |
-
1,0.27920976,0
|
| 2056 |
-
0,0.001766247,0
|
| 2057 |
-
1,0.6483464,1
|
| 2058 |
-
0,0.4439611,0
|
| 2059 |
-
0,0.28825438,0
|
| 2060 |
-
0,0.0013631671,0
|
| 2061 |
-
0,0.001132262,0
|
| 2062 |
-
0,0.0013900561,0
|
| 2063 |
-
0,0.0015789722,0
|
| 2064 |
-
0,0.0017141184,0
|
| 2065 |
-
1,0.69817054,1
|
| 2066 |
-
0,0.011650326,0
|
| 2067 |
-
0,0.001025334,0
|
| 2068 |
-
0,0.0011714353,0
|
| 2069 |
-
0,0.0008079835,0
|
| 2070 |
-
1,0.5759009,1
|
| 2071 |
-
0,0.22395323,0
|
| 2072 |
-
1,0.79228,1
|
| 2073 |
-
0,0.0011739928,0
|
| 2074 |
-
0,0.00091123086,0
|
| 2075 |
-
0,0.00080784655,0
|
| 2076 |
-
0,0.0010515592,0
|
| 2077 |
-
0,0.39967638,0
|
| 2078 |
-
0,0.0010145825,0
|
| 2079 |
-
1,0.5423705,1
|
| 2080 |
-
1,0.74459624,1
|
| 2081 |
-
0,0.00082461955,0
|
| 2082 |
-
0,0.0029421886,0
|
| 2083 |
-
0,0.43035188,0
|
| 2084 |
-
0,0.6877102,1
|
| 2085 |
-
0,0.0009364068,0
|
| 2086 |
-
0,0.0012012246,0
|
| 2087 |
-
0,0.000867792,0
|
| 2088 |
-
0,0.15176591,0
|
| 2089 |
-
1,0.77357423,1
|
| 2090 |
-
0,0.0034248335,0
|
| 2091 |
-
0,0.00086901913,0
|
| 2092 |
-
0,0.027418558,0
|
| 2093 |
-
0,0.00096116075,0
|
| 2094 |
-
1,0.6818173,1
|
| 2095 |
-
0,0.0009722299,0
|
| 2096 |
-
1,0.8720049,1
|
| 2097 |
-
0,0.00084275793,0
|
| 2098 |
-
0,0.015429355,0
|
| 2099 |
-
0,0.0009900979,0
|
| 2100 |
-
0,0.0008618162,0
|
| 2101 |
-
0,0.0010659914,0
|
| 2102 |
-
0,0.0011133268,0
|
| 2103 |
-
1,0.6155381,1
|
| 2104 |
-
0,0.0019113526,0
|
| 2105 |
-
0,0.0013318051,0
|
| 2106 |
-
0,0.0021403905,0
|
| 2107 |
-
0,0.0013350251,0
|
| 2108 |
-
0,0.0011662951,0
|
| 2109 |
-
0,0.0009040375,0
|
| 2110 |
-
0,0.0029542346,0
|
| 2111 |
-
0,0.0026244,0
|
| 2112 |
-
0,0.20788012,0
|
| 2113 |
-
0,0.5909254,1
|
| 2114 |
-
0,0.0010929306,0
|
| 2115 |
-
0,0.08327696,0
|
| 2116 |
-
0,0.001310451,0
|
| 2117 |
-
0,0.14910121,0
|
| 2118 |
-
0,0.0011818194,0
|
| 2119 |
-
0,0.001830041,0
|
| 2120 |
-
1,0.69931924,1
|
| 2121 |
-
0,0.0028769197,0
|
| 2122 |
-
0,0.38193005,0
|
| 2123 |
-
0,0.0017126591,0
|
| 2124 |
-
0,0.0008859742,0
|
| 2125 |
-
0,0.41827688,0
|
| 2126 |
-
0,0.0011043509,0
|
| 2127 |
-
0,0.27034986,0
|
| 2128 |
-
0,0.0010202782,0
|
| 2129 |
-
0,0.0010163564,0
|
| 2130 |
-
1,0.85359937,1
|
| 2131 |
-
0,0.001006123,0
|
| 2132 |
-
0,0.00084935484,0
|
| 2133 |
-
0,0.011224475,0
|
| 2134 |
-
0,0.0010811862,0
|
| 2135 |
-
0,0.0010717752,0
|
| 2136 |
-
0,0.3722477,0
|
| 2137 |
-
0,0.0010041799,0
|
| 2138 |
-
0,0.0022786073,0
|
| 2139 |
-
0,0.0010501673,0
|
| 2140 |
-
0,0.44608837,0
|
| 2141 |
-
0,0.00079554186,0
|
| 2142 |
-
0,0.49635226,0
|
| 2143 |
-
0,0.4600765,0
|
| 2144 |
-
1,0.31137487,0
|
| 2145 |
-
0,0.0009268139,0
|
| 2146 |
-
0,0.001196574,0
|
| 2147 |
-
0,0.001026193,0
|
| 2148 |
-
0,0.0009154746,0
|
| 2149 |
-
1,0.33101282,0
|
| 2150 |
-
0,0.0011979094,0
|
| 2151 |
-
0,0.0014616075,0
|
| 2152 |
-
0,0.0014142476,0
|
| 2153 |
-
1,0.8949496,1
|
| 2154 |
-
1,0.8674389,1
|
| 2155 |
-
0,0.0008275794,0
|
| 2156 |
-
0,0.00087104103,0
|
| 2157 |
-
0,0.0010423438,0
|
| 2158 |
-
0,0.31495747,0
|
| 2159 |
-
0,0.0014686183,0
|
| 2160 |
-
1,0.67044705,1
|
| 2161 |
-
0,0.2931136,0
|
| 2162 |
-
0,0.04735578,0
|
| 2163 |
-
0,0.09337471,0
|
| 2164 |
-
0,0.0009415873,0
|
| 2165 |
-
0,0.0010099063,0
|
| 2166 |
-
0,0.53362274,1
|
| 2167 |
-
0,0.00096574856,0
|
| 2168 |
-
0,0.0012844005,0
|
| 2169 |
-
0,0.00083572033,0
|
| 2170 |
-
1,0.8030804,1
|
| 2171 |
-
0,0.0010202461,0
|
| 2172 |
-
0,0.0013574356,0
|
| 2173 |
-
1,0.6834372,1
|
| 2174 |
-
0,0.09290696,0
|
| 2175 |
-
0,0.0009394532,0
|
| 2176 |
-
0,0.0009227664,0
|
| 2177 |
-
1,0.50783664,1
|
| 2178 |
-
1,0.8649702,1
|
| 2179 |
-
1,0.7761252,1
|
| 2180 |
-
0,0.0012978833,0
|
| 2181 |
-
0,0.002984829,0
|
| 2182 |
-
0,0.0010553977,0
|
| 2183 |
-
0,0.42205676,0
|
| 2184 |
-
1,0.77211195,1
|
| 2185 |
-
0,0.001033883,0
|
| 2186 |
-
0,0.63332963,1
|
| 2187 |
-
0,0.0014111141,0
|
| 2188 |
-
0,0.0010597117,0
|
| 2189 |
-
1,0.26766336,0
|
| 2190 |
-
0,0.3339755,0
|
| 2191 |
-
0,0.16533276,0
|
| 2192 |
-
1,0.6683201,1
|
| 2193 |
-
0,0.0021313336,0
|
| 2194 |
-
0,0.001062777,0
|
| 2195 |
-
1,0.58701336,1
|
| 2196 |
-
0,0.000942338,0
|
| 2197 |
-
0,0.16852604,0
|
| 2198 |
-
0,0.6969176,1
|
| 2199 |
-
0,0.001123046,0
|
| 2200 |
-
1,0.6442029,1
|
| 2201 |
-
0,0.0031450344,0
|
| 2202 |
-
0,0.0009817418,0
|
| 2203 |
-
0,0.00335572,0
|
| 2204 |
-
0,0.001056942,0
|
| 2205 |
-
1,0.59806097,1
|
| 2206 |
-
1,0.79247564,1
|
| 2207 |
-
0,0.0009922732,0
|
| 2208 |
-
0,0.46785074,0
|
| 2209 |
-
0,0.22426423,0
|
| 2210 |
-
0,0.41550258,0
|
| 2211 |
-
0,0.0008828788,0
|
| 2212 |
-
0,0.0008916164,0
|
| 2213 |
-
1,0.5482998,1
|
| 2214 |
-
0,0.0011476376,0
|
| 2215 |
-
0,0.8011317,1
|
| 2216 |
-
0,0.000926976,0
|
| 2217 |
-
0,0.0009660738,0
|
| 2218 |
-
0,0.06638102,0
|
| 2219 |
-
0,0.0012834723,0
|
| 2220 |
-
0,0.00091870327,0
|
| 2221 |
-
0,0.0008933465,0
|
| 2222 |
-
0,0.0031780212,0
|
| 2223 |
-
0,0.0016238751,0
|
| 2224 |
-
0,0.3595388,0
|
| 2225 |
-
0,0.6696774,1
|
| 2226 |
-
0,0.40328088,0
|
| 2227 |
-
0,0.013335245,0
|
| 2228 |
-
0,0.0010570795,0
|
| 2229 |
-
0,0.0013795571,0
|
| 2230 |
-
0,0.16846372,0
|
| 2231 |
-
1,0.84474844,1
|
| 2232 |
-
0,0.0008750753,0
|
| 2233 |
-
0,0.00087973947,0
|
| 2234 |
-
1,0.50548005,1
|
| 2235 |
-
0,0.0008324304,0
|
| 2236 |
-
0,0.00094381167,0
|
| 2237 |
-
0,0.00089194905,0
|
| 2238 |
-
0,0.5607766,1
|
| 2239 |
-
1,0.54241484,1
|
| 2240 |
-
0,0.00443618,0
|
| 2241 |
-
0,0.36638293,0
|
| 2242 |
-
0,0.7677964,1
|
| 2243 |
-
0,0.37992284,0
|
| 2244 |
-
1,0.70686316,1
|
| 2245 |
-
0,0.5411161,1
|
| 2246 |
-
1,0.34657082,0
|
| 2247 |
-
1,0.42200235,0
|
| 2248 |
-
0,0.0010511446,0
|
| 2249 |
-
0,0.0011092986,0
|
| 2250 |
-
0,0.10377656,0
|
| 2251 |
-
0,0.14405686,0
|
| 2252 |
-
0,0.001046122,0
|
| 2253 |
-
0,0.0011387641,0
|
| 2254 |
-
0,0.054750968,0
|
| 2255 |
-
1,0.84262145,1
|
| 2256 |
-
0,0.0011182426,0
|
| 2257 |
-
0,0.70416677,1
|
| 2258 |
-
0,0.001527749,0
|
| 2259 |
-
1,0.39897177,0
|
| 2260 |
-
0,0.0014704145,0
|
| 2261 |
-
0,0.00088706374,0
|
| 2262 |
-
0,0.0008660363,0
|
| 2263 |
-
0,0.0020009428,0
|
| 2264 |
-
1,0.7243906,1
|
| 2265 |
-
0,0.0011177352,0
|
| 2266 |
-
1,0.58792394,1
|
| 2267 |
-
0,0.0012186854,0
|
| 2268 |
-
0,0.000997782,0
|
| 2269 |
-
1,0.50993186,1
|
| 2270 |
-
1,0.08566585,0
|
| 2271 |
-
0,0.0009616673,0
|
| 2272 |
-
0,0.3577684,0
|
| 2273 |
-
0,0.00096860697,0
|
| 2274 |
-
0,0.446375,0
|
| 2275 |
-
0,0.0008482884,0
|
| 2276 |
-
0,0.0009922902,0
|
| 2277 |
-
0,0.6955276,1
|
| 2278 |
-
0,0.5420666,1
|
| 2279 |
-
0,0.0011368056,0
|
| 2280 |
-
0,0.000882534,0
|
| 2281 |
-
0,0.0009420459,0
|
| 2282 |
-
0,0.0014183329,0
|
| 2283 |
-
0,0.000876334,0
|
| 2284 |
-
0,0.00083285803,0
|
| 2285 |
-
0,0.0013245676,0
|
| 2286 |
-
0,0.00087310805,0
|
| 2287 |
-
0,0.0013295789,0
|
| 2288 |
-
0,0.6923525,1
|
| 2289 |
-
0,0.0010767679,0
|
| 2290 |
-
0,0.005409557,0
|
| 2291 |
-
0,0.0009981047,0
|
| 2292 |
-
0,0.015847243,0
|
| 2293 |
-
0,0.0013694075,0
|
| 2294 |
-
0,0.0007474202,0
|
| 2295 |
-
1,0.506777,1
|
| 2296 |
-
0,0.009123179,0
|
| 2297 |
-
0,0.49657914,0
|
| 2298 |
-
0,0.277546,0
|
| 2299 |
-
0,0.16679445,0
|
| 2300 |
-
0,0.0013762934,0
|
| 2301 |
-
0,0.0011636585,0
|
| 2302 |
-
0,0.0101664085,0
|
| 2303 |
-
0,0.0010100742,0
|
| 2304 |
-
0,0.00094104477,0
|
| 2305 |
-
0,0.32086867,0
|
| 2306 |
-
0,0.75031704,1
|
| 2307 |
-
0,0.0014099678,0
|
| 2308 |
-
0,0.0013869625,0
|
| 2309 |
-
0,0.0016245442,0
|
| 2310 |
-
0,0.0010221949,0
|
| 2311 |
-
1,0.49876264,0
|
| 2312 |
-
1,0.38857585,0
|
| 2313 |
-
0,0.0020588983,0
|
| 2314 |
-
0,0.0015847568,0
|
| 2315 |
-
1,0.6566301,1
|
| 2316 |
-
0,0.00082876394,0
|
| 2317 |
-
0,0.0008605746,0
|
| 2318 |
-
0,0.7650169,1
|
| 2319 |
-
1,0.34104934,0
|
| 2320 |
-
0,0.68121535,1
|
| 2321 |
-
0,0.0011083046,0
|
| 2322 |
-
0,0.0019450967,0
|
| 2323 |
-
1,0.51100874,1
|
| 2324 |
-
0,0.0011186971,0
|
| 2325 |
-
0,0.5587132,1
|
| 2326 |
-
0,0.00088953995,0
|
| 2327 |
-
0,0.0011756523,0
|
| 2328 |
-
1,0.65532446,1
|
| 2329 |
-
0,0.0009612486,0
|
| 2330 |
-
0,0.0021935876,0
|
| 2331 |
-
1,0.5830356,1
|
| 2332 |
-
0,0.0012566766,0
|
| 2333 |
-
1,0.60258913,1
|
| 2334 |
-
0,0.0011863246,0
|
| 2335 |
-
0,0.0009032179,0
|
| 2336 |
-
0,0.65853,1
|
| 2337 |
-
0,0.000863914,0
|
| 2338 |
-
0,0.0010347526,0
|
| 2339 |
-
1,0.7523419,1
|
| 2340 |
-
1,0.72493607,1
|
| 2341 |
-
0,0.0007886984,0
|
| 2342 |
-
0,0.0011739223,0
|
| 2343 |
-
0,0.0030585844,0
|
| 2344 |
-
0,0.0012207521,0
|
| 2345 |
-
0,0.00080916967,0
|
| 2346 |
-
1,0.64112777,1
|
| 2347 |
-
0,0.52794874,1
|
| 2348 |
-
0,0.0010475264,0
|
| 2349 |
-
0,0.6387218,1
|
| 2350 |
-
0,0.0008685981,0
|
| 2351 |
-
0,0.0009589503,0
|
| 2352 |
-
1,0.55499166,1
|
| 2353 |
-
0,0.44648212,0
|
| 2354 |
-
0,0.03675145,0
|
| 2355 |
-
0,0.2166356,0
|
| 2356 |
-
0,0.0012044515,0
|
| 2357 |
-
0,0.0013222471,0
|
| 2358 |
-
0,0.4260182,0
|
| 2359 |
-
0,0.2610802,0
|
| 2360 |
-
1,0.19833244,0
|
| 2361 |
-
0,0.00086645305,0
|
| 2362 |
-
1,0.5524298,1
|
| 2363 |
-
0,0.0009163962,0
|
| 2364 |
-
1,0.85790485,1
|
| 2365 |
-
0,0.0013663898,0
|
| 2366 |
-
0,0.0014356149,0
|
| 2367 |
-
1,0.6442424,1
|
| 2368 |
-
0,0.71102035,1
|
| 2369 |
-
1,0.5645003,1
|
| 2370 |
-
0,0.0012521823,0
|
| 2371 |
-
0,0.00092533894,0
|
| 2372 |
-
0,0.0018732402,0
|
| 2373 |
-
1,0.54946786,1
|
| 2374 |
-
0,0.0012147287,0
|
| 2375 |
-
0,0.0019167685,0
|
| 2376 |
-
0,0.0019650552,0
|
| 2377 |
-
0,0.0010118411,0
|
| 2378 |
-
0,0.001942437,0
|
| 2379 |
-
1,0.74532837,1
|
| 2380 |
-
0,0.73756206,1
|
| 2381 |
-
1,0.6251341,1
|
| 2382 |
-
0,0.27416506,0
|
| 2383 |
-
1,0.50981134,1
|
| 2384 |
-
0,0.5543468,1
|
| 2385 |
-
0,0.0009716256,0
|
| 2386 |
-
0,0.0007795425,0
|
| 2387 |
-
0,0.0019102368,0
|
| 2388 |
-
0,0.0013252186,0
|
| 2389 |
-
0,0.0011993565,0
|
| 2390 |
-
0,0.4715206,0
|
| 2391 |
-
0,0.060475998,0
|
| 2392 |
-
0,0.0010275397,0
|
| 2393 |
-
0,0.09666838,0
|
| 2394 |
-
0,0.00092452514,0
|
| 2395 |
-
0,0.00089867495,0
|
| 2396 |
-
0,0.00094824354,0
|
| 2397 |
-
0,0.59952736,1
|
| 2398 |
-
1,0.3576,0
|
| 2399 |
-
0,0.0012081462,0
|
| 2400 |
-
0,0.56758493,1
|
| 2401 |
-
1,0.6091145,1
|
| 2402 |
-
0,0.0028927878,0
|
| 2403 |
-
0,0.00090088655,0
|
| 2404 |
-
0,0.0022922496,0
|
| 2405 |
-
0,0.0008741644,0
|
| 2406 |
-
0,0.0010697065,0
|
| 2407 |
-
0,0.23075378,0
|
| 2408 |
-
0,0.2533621,0
|
| 2409 |
-
1,0.7158058,1
|
| 2410 |
-
0,0.00092310895,0
|
| 2411 |
-
0,0.0008439115,0
|
| 2412 |
-
0,0.29364723,0
|
| 2413 |
-
0,0.3924025,0
|
| 2414 |
-
0,0.00081181526,0
|
| 2415 |
-
0,0.0016517383,0
|
| 2416 |
-
1,0.65525025,1
|
| 2417 |
-
0,0.0009397724,0
|
| 2418 |
-
0,0.0010028163,0
|
| 2419 |
-
0,0.0010158733,0
|
| 2420 |
-
1,0.5670775,1
|
| 2421 |
-
0,0.00096147077,0
|
| 2422 |
-
0,0.6200651,1
|
| 2423 |
-
0,0.0022722506,0
|
| 2424 |
-
1,0.708499,1
|
| 2425 |
-
0,0.335328,0
|
| 2426 |
-
0,0.0012308101,0
|
| 2427 |
-
0,0.52159727,1
|
| 2428 |
-
0,0.68693334,1
|
| 2429 |
-
1,0.78105223,1
|
| 2430 |
-
0,0.146075,0
|
| 2431 |
-
0,0.22292003,0
|
| 2432 |
-
0,0.4519507,0
|
| 2433 |
-
0,0.00090513856,0
|
| 2434 |
-
1,0.7193635,1
|
| 2435 |
-
0,0.0008796326,0
|
| 2436 |
-
0,0.5841723,1
|
| 2437 |
-
0,0.00097607274,0
|
| 2438 |
-
0,0.49114498,0
|
| 2439 |
-
0,0.32498598,0
|
| 2440 |
-
0,0.0007685969,0
|
| 2441 |
-
0,0.0009030518,0
|
| 2442 |
-
0,0.0010781017,0
|
| 2443 |
-
0,0.0008947887,0
|
| 2444 |
-
0,0.00097482925,0
|
| 2445 |
-
0,0.0018749419,0
|
| 2446 |
-
1,0.81562716,1
|
| 2447 |
-
0,0.4467036,0
|
| 2448 |
-
0,0.0012841815,0
|
| 2449 |
-
1,0.8361843,1
|
| 2450 |
-
0,0.0028613487,0
|
| 2451 |
-
0,0.0009700805,0
|
| 2452 |
-
0,0.00091441197,0
|
| 2453 |
-
0,0.0011520925,0
|
| 2454 |
-
0,0.35427848,0
|
| 2455 |
-
0,0.0009871047,0
|
| 2456 |
-
0,0.0010851595,0
|
| 2457 |
-
0,0.0011901299,0
|
| 2458 |
-
1,0.68657166,1
|
| 2459 |
-
0,0.0008210955,0
|
| 2460 |
-
0,0.3926858,0
|
| 2461 |
-
0,0.0008118559,0
|
| 2462 |
-
0,0.001132336,0
|
| 2463 |
-
0,0.47707522,0
|
| 2464 |
-
0,0.0010266873,0
|
| 2465 |
-
1,0.49586618,0
|
| 2466 |
-
0,0.49493933,0
|
| 2467 |
-
0,0.0007915444,0
|
| 2468 |
-
0,0.0015369055,0
|
| 2469 |
-
0,0.001586034,0
|
| 2470 |
-
1,0.3755337,0
|
| 2471 |
-
0,0.42481285,0
|
| 2472 |
-
0,0.023645595,0
|
| 2473 |
-
0,0.0010042259,0
|
| 2474 |
-
0,0.00128545,0
|
| 2475 |
-
1,0.86342037,1
|
| 2476 |
-
0,0.52136,1
|
| 2477 |
-
0,0.00090365595,0
|
| 2478 |
-
0,0.00097448146,0
|
| 2479 |
-
0,0.109581724,0
|
| 2480 |
-
0,0.46571532,0
|
| 2481 |
-
0,0.00094029406,0
|
| 2482 |
-
0,0.101251364,0
|
| 2483 |
-
1,0.7810546,1
|
| 2484 |
-
0,0.46070743,0
|
| 2485 |
-
0,0.49183476,0
|
| 2486 |
-
1,0.5577063,1
|
| 2487 |
-
0,0.0015362684,0
|
| 2488 |
-
0,0.4263195,0
|
| 2489 |
-
0,0.7231806,1
|
| 2490 |
-
0,0.37002757,0
|
| 2491 |
-
0,0.752774,1
|
| 2492 |
-
0,0.000886504,0
|
| 2493 |
-
1,0.22578251,0
|
| 2494 |
-
1,0.4892348,0
|
| 2495 |
-
0,0.16326025,0
|
| 2496 |
-
0,0.67995006,1
|
| 2497 |
-
0,0.0012501619,0
|
| 2498 |
-
0,0.0010966564,0
|
| 2499 |
-
0,0.0032683455,0
|
| 2500 |
-
0,0.4806333,0
|
| 2501 |
-
0,0.00090476696,0
|
| 2502 |
-
0,0.0015114724,0
|
| 2503 |
-
0,0.0008801575,0
|
| 2504 |
-
0,0.0013020836,0
|
| 2505 |
-
0,0.0008198137,0
|
| 2506 |
-
0,0.39454362,0
|
| 2507 |
-
0,0.00093733915,0
|
| 2508 |
-
1,0.2374296,0
|
| 2509 |
-
0,0.4772929,0
|
| 2510 |
-
0,0.00083000946,0
|
| 2511 |
-
1,0.39207602,0
|
| 2512 |
-
1,0.21760413,0
|
| 2513 |
-
0,0.00088210235,0
|
| 2514 |
-
0,0.4371137,0
|
| 2515 |
-
0,0.003961796,0
|
| 2516 |
-
1,0.407404,0
|
| 2517 |
-
1,0.32632312,0
|
| 2518 |
-
0,0.0013358905,0
|
| 2519 |
-
0,0.44445646,0
|
| 2520 |
-
0,0.00078979187,0
|
| 2521 |
-
0,0.001217051,0
|
| 2522 |
-
1,0.14045572,0
|
| 2523 |
-
0,0.00090959703,0
|
| 2524 |
-
0,0.12760106,0
|
| 2525 |
-
0,0.00092583685,0
|
| 2526 |
-
0,0.0010166396,0
|
| 2527 |
-
0,0.0013670418,0
|
| 2528 |
-
0,0.0011671932,0
|
| 2529 |
-
0,0.0009024935,0
|
| 2530 |
-
0,0.0009488663,0
|
| 2531 |
-
0,0.23633154,0
|
| 2532 |
-
0,0.0009680432,0
|
| 2533 |
-
0,0.005821207,0
|
| 2534 |
-
0,0.001028018,0
|
| 2535 |
-
1,0.8682412,1
|
| 2536 |
-
0,0.0031583507,0
|
| 2537 |
-
0,0.0009497762,0
|
| 2538 |
-
0,0.001053368,0
|
| 2539 |
-
0,0.32452446,0
|
| 2540 |
-
0,0.2644262,0
|
| 2541 |
-
0,0.0016570487,0
|
| 2542 |
-
1,0.41961834,0
|
| 2543 |
-
0,0.002138197,0
|
| 2544 |
-
0,0.50461143,1
|
| 2545 |
-
1,0.31562546,0
|
| 2546 |
-
0,0.00085728854,0
|
| 2547 |
-
1,0.7799668,1
|
| 2548 |
-
1,0.60938144,1
|
| 2549 |
-
1,0.70686054,1
|
| 2550 |
-
0,0.14313947,0
|
| 2551 |
-
0,0.0019669319,0
|
| 2552 |
-
0,0.0014327079,0
|
| 2553 |
-
1,0.35944593,0
|
| 2554 |
-
0,0.001155221,0
|
| 2555 |
-
0,0.0012412227,0
|
| 2556 |
-
0,0.65372455,1
|
| 2557 |
-
0,0.09133776,0
|
| 2558 |
-
0,0.00088334584,0
|
| 2559 |
-
0,0.312847,0
|
| 2560 |
-
0,0.00094589975,0
|
| 2561 |
-
0,0.02113242,0
|
| 2562 |
-
0,0.00085028814,0
|
| 2563 |
-
1,0.679118,1
|
| 2564 |
-
0,0.0010827653,0
|
| 2565 |
-
0,0.6941961,1
|
| 2566 |
-
0,0.22313194,0
|
| 2567 |
-
0,0.545091,1
|
| 2568 |
-
0,0.0009083876,0
|
| 2569 |
-
0,0.46453863,0
|
| 2570 |
-
0,0.26321778,0
|
| 2571 |
-
1,0.75012404,1
|
| 2572 |
-
0,0.0009906325,0
|
| 2573 |
-
1,0.6568036,1
|
| 2574 |
-
0,0.0009753449,0
|
| 2575 |
-
0,0.07195825,0
|
| 2576 |
-
1,0.6351574,1
|
| 2577 |
-
0,0.1408972,0
|
| 2578 |
-
0,0.001222103,0
|
| 2579 |
-
1,0.77539057,1
|
| 2580 |
-
0,0.5444014,1
|
| 2581 |
-
0,0.0010118657,0
|
| 2582 |
-
1,0.48471546,0
|
| 2583 |
-
0,0.0016892834,0
|
| 2584 |
-
0,0.0010253623,0
|
| 2585 |
-
0,0.52112716,1
|
| 2586 |
-
0,0.0014120834,0
|
| 2587 |
-
0,0.00082746346,0
|
| 2588 |
-
0,0.027199497,0
|
| 2589 |
-
0,0.05421511,0
|
| 2590 |
-
1,0.31981403,0
|
| 2591 |
-
0,0.0008672601,0
|
| 2592 |
-
0,0.00398289,0
|
| 2593 |
-
0,0.00096599694,0
|
| 2594 |
-
0,0.0008861936,0
|
| 2595 |
-
0,0.001077711,0
|
| 2596 |
-
0,0.0040958757,0
|
| 2597 |
-
1,0.73662597,1
|
| 2598 |
-
1,0.7852967,1
|
| 2599 |
-
0,0.0009536717,0
|
| 2600 |
-
0,0.22212833,0
|
| 2601 |
-
0,0.0012771386,0
|
| 2602 |
-
0,0.14009744,0
|
| 2603 |
-
0,0.22057438,0
|
| 2604 |
-
1,0.57046175,1
|
| 2605 |
-
1,0.77926654,1
|
| 2606 |
-
0,0.00082518166,0
|
| 2607 |
-
0,0.0009310219,0
|
| 2608 |
-
0,0.00091517466,0
|
| 2609 |
-
0,0.00088250585,0
|
| 2610 |
-
0,0.5822396,1
|
| 2611 |
-
0,0.00081105594,0
|
| 2612 |
-
1,0.2343802,0
|
| 2613 |
-
0,0.0011367487,0
|
| 2614 |
-
0,0.00117123,0
|
| 2615 |
-
1,0.60446376,1
|
| 2616 |
-
1,0.31510806,0
|
| 2617 |
-
0,0.2251927,0
|
| 2618 |
-
1,0.6748813,1
|
| 2619 |
-
0,0.0010593968,0
|
| 2620 |
-
1,0.32355276,0
|
| 2621 |
-
0,0.0011763906,0
|
| 2622 |
-
0,0.0021751213,0
|
| 2623 |
-
0,0.0026256007,0
|
| 2624 |
-
0,0.0009737066,0
|
| 2625 |
-
1,0.75336325,1
|
| 2626 |
-
0,0.0010090931,0
|
| 2627 |
-
0,0.00094570237,0
|
| 2628 |
-
0,0.17809871,0
|
| 2629 |
-
0,0.0008351583,0
|
| 2630 |
-
1,0.6894662,1
|
| 2631 |
-
0,0.16999531,0
|
| 2632 |
-
0,0.0009210333,0
|
| 2633 |
-
0,0.00092500576,0
|
| 2634 |
-
1,0.6222378,1
|
| 2635 |
-
0,0.5345441,1
|
| 2636 |
-
0,0.10253339,0
|
| 2637 |
-
1,0.5398624,1
|
| 2638 |
-
1,0.70694554,1
|
| 2639 |
-
0,0.14037986,0
|
| 2640 |
-
1,0.73327917,1
|
| 2641 |
-
1,0.46298742,0
|
| 2642 |
-
0,0.0013656756,0
|
| 2643 |
-
0,0.0009719117,0
|
| 2644 |
-
0,0.32651603,0
|
| 2645 |
-
0,0.0010383307,0
|
| 2646 |
-
0,0.00082847173,0
|
| 2647 |
-
0,0.3687943,0
|
| 2648 |
-
0,0.17860247,0
|
| 2649 |
-
0,0.8183667,1
|
| 2650 |
-
0,0.00082963787,0
|
| 2651 |
-
0,0.80729526,1
|
| 2652 |
-
0,0.15151022,0
|
| 2653 |
-
0,0.0011345402,0
|
| 2654 |
-
0,0.00081908173,0
|
| 2655 |
-
0,0.39142883,0
|
| 2656 |
-
0,0.0016399269,0
|
| 2657 |
-
0,0.27851093,0
|
| 2658 |
-
1,0.7010973,1
|
| 2659 |
-
0,0.007820642,0
|
| 2660 |
-
0,0.0018496078,0
|
| 2661 |
-
0,0.0019736353,0
|
| 2662 |
-
0,0.0009546348,0
|
| 2663 |
-
0,0.0014451279,0
|
| 2664 |
-
1,0.5517119,1
|
| 2665 |
-
0,0.010127031,0
|
| 2666 |
-
0,0.0009856458,0
|
| 2667 |
-
0,0.00091644377,0
|
| 2668 |
-
0,0.00095582794,0
|
| 2669 |
-
0,0.0011918532,0
|
| 2670 |
-
0,0.0010818146,0
|
| 2671 |
-
0,0.0013008608,0
|
| 2672 |
-
0,0.45276558,0
|
| 2673 |
-
0,0.001402461,0
|
| 2674 |
-
0,0.0010138496,0
|
| 2675 |
-
0,0.0019704031,0
|
| 2676 |
-
0,0.43915036,0
|
| 2677 |
-
0,0.0036121928,0
|
| 2678 |
-
1,0.7637573,1
|
| 2679 |
-
0,0.0010554554,0
|
| 2680 |
-
0,0.001165051,0
|
| 2681 |
-
0,0.37836757,0
|
| 2682 |
-
1,0.38657185,0
|
| 2683 |
-
0,0.00091857504,0
|
| 2684 |
-
1,0.7066505,1
|
| 2685 |
-
0,0.0009607843,0
|
| 2686 |
-
0,0.0009117129,0
|
| 2687 |
-
0,0.5112803,1
|
| 2688 |
-
0,0.0009124838,0
|
| 2689 |
-
0,0.0007892426,0
|
| 2690 |
-
0,0.0008387871,0
|
| 2691 |
-
1,0.57059675,1
|
| 2692 |
-
0,0.0008657701,0
|
| 2693 |
-
0,0.7825752,1
|
| 2694 |
-
0,0.0010295792,0
|
| 2695 |
-
0,0.0008365171,0
|
| 2696 |
-
1,0.791482,1
|
| 2697 |
-
1,0.11556648,0
|
| 2698 |
-
0,0.035818722,0
|
| 2699 |
-
1,0.67135835,1
|
| 2700 |
-
1,0.8084183,1
|
| 2701 |
-
0,0.0013094479,0
|
| 2702 |
-
0,0.7851747,1
|
| 2703 |
-
0,0.0012969027,0
|
| 2704 |
-
1,0.6496594,1
|
| 2705 |
-
1,0.7269179,1
|
| 2706 |
-
0,0.0010638919,0
|
| 2707 |
-
0,0.0008244636,0
|
| 2708 |
-
1,0.2654086,0
|
| 2709 |
-
0,0.33399442,0
|
| 2710 |
-
0,0.0028439101,0
|
| 2711 |
-
1,0.44304168,0
|
| 2712 |
-
0,0.0008866658,0
|
| 2713 |
-
0,0.00084273505,0
|
| 2714 |
-
0,0.0020351016,0
|
| 2715 |
-
0,0.0008314609,0
|
| 2716 |
-
0,0.0012523923,0
|
| 2717 |
-
0,0.00095604156,0
|
| 2718 |
-
0,0.0017479339,0
|
| 2719 |
-
0,0.051447127,0
|
| 2720 |
-
0,0.0009557702,0
|
| 2721 |
-
0,0.0009667441,0
|
| 2722 |
-
0,0.002198847,0
|
| 2723 |
-
0,0.12981664,0
|
| 2724 |
-
0,0.0012121114,0
|
| 2725 |
-
0,0.0079048155,0
|
| 2726 |
-
0,0.00077453733,0
|
| 2727 |
-
0,0.61904925,1
|
| 2728 |
-
1,0.51487285,1
|
| 2729 |
-
1,0.71629065,1
|
| 2730 |
-
0,0.0010181725,0
|
| 2731 |
-
1,0.20445453,0
|
| 2732 |
-
0,0.0009306721,0
|
| 2733 |
-
0,0.00092179497,0
|
| 2734 |
-
1,0.53567547,1
|
| 2735 |
-
0,0.0008878104,0
|
| 2736 |
-
0,0.14483893,0
|
| 2737 |
-
0,0.0022033292,0
|
| 2738 |
-
0,0.0012140451,0
|
| 2739 |
-
0,0.0011769638,0
|
| 2740 |
-
0,0.00094178476,0
|
| 2741 |
-
0,0.0013351414,0
|
| 2742 |
-
0,0.0010231428,0
|
| 2743 |
-
0,0.0028764764,0
|
| 2744 |
-
0,0.00088782865,0
|
| 2745 |
-
0,0.0010239044,0
|
| 2746 |
-
0,0.0008212328,0
|
| 2747 |
-
0,0.00097123464,0
|
| 2748 |
-
0,0.0011692323,0
|
| 2749 |
-
0,0.00095932,0
|
| 2750 |
-
0,0.0009633353,0
|
| 2751 |
-
0,0.35725975,0
|
| 2752 |
-
1,0.69560957,1
|
| 2753 |
-
0,0.0010163616,0
|
| 2754 |
-
0,0.00086547283,0
|
| 2755 |
-
1,0.8560721,1
|
| 2756 |
-
0,0.00092080346,0
|
| 2757 |
-
1,0.44005793,0
|
| 2758 |
-
0,0.0009872523,0
|
| 2759 |
-
0,0.0008560454,0
|
| 2760 |
-
0,0.0012110769,0
|
| 2761 |
-
0,0.0008555859,0
|
| 2762 |
-
0,0.20835932,0
|
| 2763 |
-
0,0.0012324532,0
|
| 2764 |
-
0,0.0011165539,0
|
| 2765 |
-
1,0.73311216,1
|
| 2766 |
-
0,0.0012142817,0
|
| 2767 |
-
1,0.72214335,1
|
| 2768 |
-
0,0.0012687293,0
|
| 2769 |
-
0,0.0007941535,0
|
| 2770 |
-
0,0.052786954,0
|
| 2771 |
-
0,0.0011843917,0
|
| 2772 |
-
0,0.0008150518,0
|
| 2773 |
-
0,0.0010267626,0
|
| 2774 |
-
1,0.6452032,1
|
| 2775 |
-
0,0.3575322,0
|
| 2776 |
-
1,0.7990738,1
|
| 2777 |
-
0,0.0010931097,0
|
| 2778 |
-
0,0.0016088708,0
|
| 2779 |
-
1,0.6296911,1
|
| 2780 |
-
1,0.20946522,0
|
| 2781 |
-
0,0.0010309608,0
|
| 2782 |
-
1,0.7263279,1
|
| 2783 |
-
1,0.63071215,1
|
| 2784 |
-
0,0.013318661,0
|
| 2785 |
-
0,0.0012621472,0
|
| 2786 |
-
1,0.72783095,1
|
| 2787 |
-
0,0.0011353592,0
|
| 2788 |
-
0,0.00095678604,0
|
| 2789 |
-
0,0.001396949,0
|
| 2790 |
-
1,0.55972546,1
|
| 2791 |
-
0,0.0010440184,0
|
| 2792 |
-
0,0.27277997,0
|
| 2793 |
-
0,0.00091744936,0
|
| 2794 |
-
0,0.23269582,0
|
| 2795 |
-
0,0.0008658086,0
|
| 2796 |
-
0,0.00085602904,0
|
| 2797 |
-
0,0.0010250561,0
|
| 2798 |
-
0,0.00084345846,0
|
| 2799 |
-
1,0.66574997,1
|
| 2800 |
-
0,0.0011059438,0
|
| 2801 |
-
0,0.0011090867,0
|
| 2802 |
-
1,0.33194953,0
|
| 2803 |
-
0,0.0011119393,0
|
| 2804 |
-
0,0.6470907,1
|
| 2805 |
-
1,0.7327194,1
|
| 2806 |
-
1,0.7710486,1
|
| 2807 |
-
0,0.11899223,0
|
| 2808 |
-
0,0.36645073,0
|
| 2809 |
-
0,0.110385925,0
|
| 2810 |
-
0,0.0010350458,0
|
| 2811 |
-
1,0.7796217,1
|
| 2812 |
-
0,0.00096380076,0
|
| 2813 |
-
0,0.49681517,0
|
| 2814 |
-
0,0.0010481027,0
|
| 2815 |
-
0,0.6842804,1
|
| 2816 |
-
0,0.13682581,0
|
| 2817 |
-
0,0.05944811,0
|
| 2818 |
-
0,0.0012068587,0
|
| 2819 |
-
0,0.005288887,0
|
| 2820 |
-
0,0.65641546,1
|
| 2821 |
-
0,0.0008991473,0
|
| 2822 |
-
0,0.00090689753,0
|
| 2823 |
-
0,0.46010032,0
|
| 2824 |
-
0,0.14822635,0
|
| 2825 |
-
0,0.0011970259,0
|
| 2826 |
-
0,0.4271773,0
|
| 2827 |
-
0,0.3656186,0
|
| 2828 |
-
0,0.8195093,1
|
| 2829 |
-
1,0.5701496,1
|
| 2830 |
-
0,0.4221881,0
|
| 2831 |
-
0,0.0010203084,0
|
| 2832 |
-
1,0.74877393,1
|
| 2833 |
-
0,0.0009964638,0
|
| 2834 |
-
0,0.09252277,0
|
| 2835 |
-
0,0.32190117,0
|
| 2836 |
-
0,0.2631963,0
|
| 2837 |
-
0,0.5640367,1
|
| 2838 |
-
0,0.000990022,0
|
| 2839 |
-
1,0.20268653,0
|
| 2840 |
-
0,0.11683034,0
|
| 2841 |
-
0,0.004007565,0
|
| 2842 |
-
0,0.0017538378,0
|
| 2843 |
-
1,0.78015506,1
|
| 2844 |
-
1,0.101948835,0
|
| 2845 |
-
0,0.4630364,0
|
| 2846 |
-
0,0.0012770256,0
|
| 2847 |
-
0,0.0008438017,0
|
| 2848 |
-
0,0.0009979869,0
|
| 2849 |
-
1,0.51085657,1
|
| 2850 |
-
0,0.00088811293,0
|
| 2851 |
-
0,0.001140972,0
|
| 2852 |
-
0,0.18803759,0
|
| 2853 |
-
0,0.003006536,0
|
| 2854 |
-
0,0.0011695281,0
|
| 2855 |
-
0,0.12834576,0
|
| 2856 |
-
1,0.44561812,0
|
| 2857 |
-
0,0.0015858911,0
|
| 2858 |
-
1,0.3685857,0
|
| 2859 |
-
1,0.21091098,0
|
| 2860 |
-
0,0.4187445,0
|
| 2861 |
-
1,0.78673923,1
|
| 2862 |
-
0,0.0011236598,0
|
| 2863 |
-
0,0.0010910762,0
|
| 2864 |
-
0,0.0009057033,0
|
| 2865 |
-
0,0.0014176408,0
|
| 2866 |
-
1,0.67693305,1
|
| 2867 |
-
0,0.04034407,0
|
| 2868 |
-
0,0.0008000367,0
|
| 2869 |
-
0,0.0010251978,0
|
| 2870 |
-
0,0.001136779,0
|
| 2871 |
-
0,0.0014142994,0
|
| 2872 |
-
0,0.0008644599,0
|
| 2873 |
-
0,0.0009202785,0
|
| 2874 |
-
0,0.005384131,0
|
| 2875 |
-
0,0.00090876326,0
|
| 2876 |
-
0,0.1600891,0
|
| 2877 |
-
0,0.37699142,0
|
| 2878 |
-
0,0.0010899481,0
|
| 2879 |
-
0,0.57948345,1
|
| 2880 |
-
0,0.00081056,0
|
| 2881 |
-
0,0.0012006091,0
|
| 2882 |
-
1,0.77439034,1
|
| 2883 |
-
0,0.00091627566,0
|
| 2884 |
-
0,0.5347407,1
|
| 2885 |
-
0,0.0009845204,0
|
| 2886 |
-
1,0.4189611,0
|
| 2887 |
-
1,0.53350735,1
|
| 2888 |
-
0,0.0010007151,0
|
| 2889 |
-
0,0.0020579454,0
|
| 2890 |
-
0,0.0012446969,0
|
| 2891 |
-
0,0.000945488,0
|
| 2892 |
-
0,0.4172918,0
|
| 2893 |
-
0,0.000938053,0
|
| 2894 |
-
0,0.5179231,1
|
| 2895 |
-
0,0.0011195596,0
|
| 2896 |
-
0,0.00086770067,0
|
| 2897 |
-
1,0.7677404,1
|
| 2898 |
-
0,0.40908203,0
|
| 2899 |
-
1,0.67145306,1
|
| 2900 |
-
0,0.0009983968,0
|
| 2901 |
-
0,0.4357218,0
|
| 2902 |
-
1,0.40684062,0
|
| 2903 |
-
0,0.52875257,1
|
| 2904 |
-
0,0.0009659417,0
|
| 2905 |
-
0,0.0008981062,0
|
| 2906 |
-
0,0.0009528588,0
|
| 2907 |
-
0,0.00095595414,0
|
| 2908 |
-
1,0.56750405,1
|
| 2909 |
-
0,0.0010651433,0
|
| 2910 |
-
1,0.13752264,0
|
| 2911 |
-
0,0.00079181977,0
|
| 2912 |
-
0,0.0008729538,0
|
| 2913 |
-
0,0.0015384284,0
|
| 2914 |
-
0,0.58224595,1
|
| 2915 |
-
0,0.000829641,0
|
| 2916 |
-
1,0.45509866,0
|
| 2917 |
-
0,0.0008586016,0
|
| 2918 |
-
0,0.0008085296,0
|
| 2919 |
-
0,0.0009752612,0
|
| 2920 |
-
1,0.65100336,1
|
| 2921 |
-
0,0.0012406152,0
|
| 2922 |
-
1,0.7015345,1
|
| 2923 |
-
0,0.0008786278,0
|
| 2924 |
-
0,0.0012137983,0
|
| 2925 |
-
0,0.00094313076,0
|
| 2926 |
-
0,0.47724065,0
|
| 2927 |
-
1,0.76214045,1
|
| 2928 |
-
0,0.0012182803,0
|
| 2929 |
-
1,0.7183965,1
|
| 2930 |
-
1,0.46909162,0
|
| 2931 |
-
0,0.00082692713,0
|
| 2932 |
-
0,0.001267697,0
|
| 2933 |
-
1,0.7449686,1
|
| 2934 |
-
0,0.0009520602,0
|
| 2935 |
-
1,0.2682325,0
|
| 2936 |
-
1,0.8336221,1
|
| 2937 |
-
0,0.05860209,0
|
| 2938 |
-
1,0.5426554,1
|
| 2939 |
-
0,0.0015246021,0
|
| 2940 |
-
0,0.0009006676,0
|
| 2941 |
-
0,0.0009767053,0
|
| 2942 |
-
0,0.00081073726,0
|
| 2943 |
-
1,0.54334396,1
|
| 2944 |
-
1,0.72451127,1
|
| 2945 |
-
1,0.64862955,1
|
| 2946 |
-
0,0.0015333411,0
|
| 2947 |
-
0,0.0009270903,0
|
| 2948 |
-
0,0.0008060023,0
|
| 2949 |
-
0,0.0015449867,0
|
| 2950 |
-
0,0.0008756904,0
|
| 2951 |
-
0,0.00089255895,0
|
| 2952 |
-
1,0.4240961,0
|
| 2953 |
-
0,0.0007905915,0
|
| 2954 |
-
1,0.6176073,1
|
| 2955 |
-
0,0.16892567,0
|
| 2956 |
-
0,0.0019416468,0
|
| 2957 |
-
0,0.0034173308,0
|
| 2958 |
-
1,0.48232934,0
|
| 2959 |
-
0,0.3651226,0
|
| 2960 |
-
1,0.5432746,1
|
| 2961 |
-
1,0.53853256,1
|
| 2962 |
-
0,0.049414482,0
|
| 2963 |
-
1,0.61953956,1
|
| 2964 |
-
0,0.00086147717,0
|
| 2965 |
-
0,0.5655044,1
|
| 2966 |
-
1,0.5725881,1
|
| 2967 |
-
0,0.38828382,0
|
| 2968 |
-
0,0.0008593283,0
|
| 2969 |
-
0,0.6609687,1
|
| 2970 |
-
0,0.0031328383,0
|
| 2971 |
-
0,0.0016064071,0
|
| 2972 |
-
0,0.16550626,0
|
| 2973 |
-
0,0.0010093974,0
|
| 2974 |
-
1,0.78001803,1
|
| 2975 |
-
0,0.0009410012,0
|
| 2976 |
-
0,0.019389188,0
|
| 2977 |
-
1,0.8118031,1
|
| 2978 |
-
0,0.0010585715,0
|
| 2979 |
-
0,0.0011178768,0
|
| 2980 |
-
0,0.51671094,1
|
| 2981 |
-
1,0.14671549,0
|
| 2982 |
-
0,0.0012277119,0
|
| 2983 |
-
0,0.10688765,0
|
| 2984 |
-
0,0.26730445,0
|
| 2985 |
-
1,0.41755715,0
|
| 2986 |
-
0,0.6532114,1
|
| 2987 |
-
1,0.5408069,1
|
| 2988 |
-
0,0.0011035037,0
|
| 2989 |
-
0,0.38721368,0
|
| 2990 |
-
0,0.0009238639,0
|
| 2991 |
-
0,0.00087872625,0
|
| 2992 |
-
1,0.4663372,0
|
| 2993 |
-
0,0.25327864,0
|
| 2994 |
-
0,0.0011890574,0
|
| 2995 |
-
0,0.08030304,0
|
| 2996 |
-
0,0.0012005063,0
|
| 2997 |
-
0,0.0008608408,0
|
| 2998 |
-
1,0.5602192,1
|
| 2999 |
-
1,0.7817263,1
|
| 3000 |
-
0,0.00083466305,0
|
| 3001 |
-
0,0.0010471261,0
|
| 3002 |
-
0,0.001124526,0
|
| 3003 |
-
0,0.0012333503,0
|
| 3004 |
-
0,0.012108352,0
|
| 3005 |
-
0,0.21776897,0
|
| 3006 |
-
0,0.0016868615,0
|
| 3007 |
-
1,0.82101417,1
|
| 3008 |
-
0,0.0009039936,0
|
| 3009 |
-
0,0.45584226,0
|
| 3010 |
-
0,0.45311502,0
|
| 3011 |
-
0,0.00096501835,0
|
| 3012 |
-
1,0.66155726,1
|
| 3013 |
-
0,0.0008598758,0
|
| 3014 |
-
0,0.49483657,0
|
| 3015 |
-
0,0.0012214681,0
|
| 3016 |
-
0,0.0012331072,0
|
| 3017 |
-
0,0.0007621751,0
|
| 3018 |
-
0,0.0063888626,0
|
| 3019 |
-
0,0.0009407718,0
|
| 3020 |
-
0,0.0008379926,0
|
| 3021 |
-
0,0.0012965859,0
|
| 3022 |
-
1,0.5956936,1
|
| 3023 |
-
1,0.27054062,0
|
| 3024 |
-
0,0.3623484,0
|
| 3025 |
-
1,0.7357546,1
|
| 3026 |
-
0,0.00154095,0
|
| 3027 |
-
0,0.003388907,0
|
| 3028 |
-
0,0.0011851212,0
|
| 3029 |
-
0,0.0010015048,0
|
| 3030 |
-
0,0.0009875342,0
|
| 3031 |
-
0,0.56412435,1
|
| 3032 |
-
0,0.17703854,0
|
| 3033 |
-
0,0.0049527115,0
|
| 3034 |
-
0,0.41469368,0
|
| 3035 |
-
0,0.0013113499,0
|
| 3036 |
-
0,0.25928286,0
|
| 3037 |
-
0,0.00080234197,0
|
| 3038 |
-
0,0.0011141081,0
|
| 3039 |
-
0,0.0009273143,0
|
| 3040 |
-
0,0.15817499,0
|
| 3041 |
-
1,0.55791485,1
|
| 3042 |
-
0,0.49433956,0
|
| 3043 |
-
0,0.44379577,0
|
| 3044 |
-
0,0.00097380485,0
|
| 3045 |
-
0,0.00081098836,0
|
| 3046 |
-
0,0.09645127,0
|
| 3047 |
-
0,0.0010052086,0
|
| 3048 |
-
0,0.0008312449,0
|
| 3049 |
-
0,0.0014976384,0
|
| 3050 |
-
0,0.7013783,1
|
| 3051 |
-
0,0.7363814,1
|
| 3052 |
-
0,0.0017869734,0
|
| 3053 |
-
0,0.5600073,1
|
| 3054 |
-
1,0.5933672,1
|
| 3055 |
-
1,0.7463558,1
|
| 3056 |
-
0,0.00087595376,0
|
| 3057 |
-
0,0.3700118,0
|
| 3058 |
-
0,0.4078047,0
|
| 3059 |
-
1,0.70503634,1
|
| 3060 |
-
0,0.00087403157,0
|
| 3061 |
-
1,0.59753084,1
|
| 3062 |
-
1,0.28408647,0
|
| 3063 |
-
1,0.346187,0
|
| 3064 |
-
0,0.0010321912,0
|
| 3065 |
-
0,0.0010446163,0
|
| 3066 |
-
0,0.0019384956,0
|
| 3067 |
-
0,0.2687528,0
|
| 3068 |
-
0,0.002051823,0
|
| 3069 |
-
0,0.00088337995,0
|
| 3070 |
-
0,0.0009888418,0
|
| 3071 |
-
1,0.518932,1
|
| 3072 |
-
0,0.0014419496,0
|
| 3073 |
-
0,0.28866056,0
|
| 3074 |
-
0,0.0018917412,0
|
| 3075 |
-
0,0.7118644,1
|
| 3076 |
-
0,0.0007667698,0
|
| 3077 |
-
0,0.0011144138,0
|
| 3078 |
-
0,0.0010100271,0
|
| 3079 |
-
0,0.0023653419,0
|
| 3080 |
-
1,0.5826993,1
|
| 3081 |
-
1,0.2670869,0
|
| 3082 |
-
0,0.0012896777,0
|
| 3083 |
-
0,0.00094871124,0
|
| 3084 |
-
0,0.0011521467,0
|
| 3085 |
-
0,0.0010139366,0
|
| 3086 |
-
0,0.081099555,0
|
| 3087 |
-
0,0.0012478646,0
|
| 3088 |
-
0,0.0009908755,0
|
| 3089 |
-
0,0.00085459027,0
|
| 3090 |
-
1,0.74782383,1
|
| 3091 |
-
1,0.63217145,1
|
| 3092 |
-
0,0.001039582,0
|
| 3093 |
-
0,0.0010658543,0
|
| 3094 |
-
0,0.0015608877,0
|
| 3095 |
-
0,0.00084603427,0
|
| 3096 |
-
1,0.69931394,1
|
| 3097 |
-
1,0.6545752,1
|
| 3098 |
-
0,0.0012512095,0
|
| 3099 |
-
0,0.25088048,0
|
| 3100 |
-
1,0.46437111,0
|
| 3101 |
-
0,0.00084795104,0
|
| 3102 |
-
0,0.0008775966,0
|
| 3103 |
-
0,0.0014866656,0
|
| 3104 |
-
0,0.0010228608,0
|
| 3105 |
-
1,0.7594941,1
|
| 3106 |
-
0,0.0010856481,0
|
| 3107 |
-
1,0.6442902,1
|
| 3108 |
-
0,0.0011075162,0
|
| 3109 |
-
1,0.7875049,1
|
| 3110 |
-
0,0.0009568744,0
|
| 3111 |
-
1,0.8581466,1
|
| 3112 |
-
0,0.00096314953,0
|
| 3113 |
-
0,0.0009732034,0
|
| 3114 |
-
0,0.0011903605,0
|
| 3115 |
-
0,0.000776502,0
|
| 3116 |
-
1,0.7685739,1
|
| 3117 |
-
0,0.0008151252,0
|
| 3118 |
-
0,0.001778643,0
|
| 3119 |
-
1,0.74879175,1
|
| 3120 |
-
0,0.20462619,0
|
| 3121 |
-
0,0.0010626859,0
|
| 3122 |
-
0,0.00089822855,0
|
| 3123 |
-
1,0.30061033,0
|
| 3124 |
-
0,0.20656452,0
|
| 3125 |
-
0,0.001094946,0
|
| 3126 |
-
1,0.7288647,1
|
| 3127 |
-
0,0.00084817404,0
|
| 3128 |
-
1,0.7852595,1
|
| 3129 |
-
0,0.0009366696,0
|
| 3130 |
-
1,0.75811154,1
|
| 3131 |
-
0,0.43617013,0
|
| 3132 |
-
0,0.0009005994,0
|
| 3133 |
-
0,0.0008882077,0
|
| 3134 |
-
0,0.0008365301,0
|
| 3135 |
-
1,0.3134185,0
|
| 3136 |
-
0,0.00086332444,0
|
| 3137 |
-
0,0.2912862,0
|
| 3138 |
-
0,0.0009431514,0
|
| 3139 |
-
0,0.0011384224,0
|
| 3140 |
-
1,0.7892175,1
|
| 3141 |
-
0,0.100751415,0
|
| 3142 |
-
0,0.00085538207,0
|
| 3143 |
-
0,0.00081628014,0
|
| 3144 |
-
0,0.14945064,0
|
| 3145 |
-
0,0.0011895921,0
|
| 3146 |
-
0,0.0008453175,0
|
| 3147 |
-
1,0.7127246,1
|
| 3148 |
-
0,0.0014821741,0
|
| 3149 |
-
0,0.48956093,0
|
| 3150 |
-
1,0.6526354,1
|
| 3151 |
-
0,0.70968366,1
|
| 3152 |
-
0,0.0013969796,0
|
| 3153 |
-
0,0.15343508,0
|
| 3154 |
-
0,0.0011304358,0
|
| 3155 |
-
1,0.69095427,1
|
| 3156 |
-
0,0.0013296176,0
|
| 3157 |
-
0,0.0012575921,0
|
| 3158 |
-
0,0.00092597713,0
|
| 3159 |
-
0,0.0008999059,0
|
| 3160 |
-
0,0.24098817,0
|
| 3161 |
-
0,0.0016370128,0
|
| 3162 |
-
0,0.0012018238,0
|
| 3163 |
-
0,0.2096467,0
|
| 3164 |
-
1,0.86404353,1
|
| 3165 |
-
0,0.022236077,0
|
| 3166 |
-
0,0.0010826504,0
|
| 3167 |
-
0,0.022941377,0
|
| 3168 |
-
1,0.7931688,1
|
| 3169 |
-
0,0.001005007,0
|
| 3170 |
-
0,0.0014776052,0
|
| 3171 |
-
0,0.007663459,0
|
| 3172 |
-
0,0.42009014,0
|
| 3173 |
-
0,0.0010014976,0
|
| 3174 |
-
0,0.0014163717,0
|
| 3175 |
-
1,0.28360623,0
|
| 3176 |
-
0,0.002345287,0
|
| 3177 |
-
0,0.0010593463,0
|
| 3178 |
-
0,0.14843744,0
|
| 3179 |
-
0,0.31841034,0
|
| 3180 |
-
0,0.0011871555,0
|
| 3181 |
-
1,0.58255094,1
|
| 3182 |
-
0,0.00079310685,0
|
| 3183 |
-
0,0.0015736327,0
|
| 3184 |
-
1,0.42560086,0
|
| 3185 |
-
1,0.79392105,1
|
| 3186 |
-
1,0.8037611,1
|
| 3187 |
-
0,0.15339166,0
|
| 3188 |
-
0,0.44212908,0
|
| 3189 |
-
0,0.15695691,0
|
| 3190 |
-
0,0.001069759,0
|
| 3191 |
-
0,0.0013624149,0
|
| 3192 |
-
0,0.10386715,0
|
| 3193 |
-
1,0.38520056,0
|
| 3194 |
-
0,0.00083144975,0
|
| 3195 |
-
1,0.51295435,1
|
| 3196 |
-
0,0.3737823,0
|
| 3197 |
-
0,0.00096323073,0
|
| 3198 |
-
0,0.0011089299,0
|
| 3199 |
-
0,0.00083140214,0
|
| 3200 |
-
0,0.001142048,0
|
| 3201 |
-
1,0.39718005,0
|
| 3202 |
-
0,0.6380984,1
|
| 3203 |
-
0,0.12990019,0
|
| 3204 |
-
0,0.6759345,1
|
| 3205 |
-
0,0.000828972,0
|
| 3206 |
-
0,0.0065064146,0
|
| 3207 |
-
0,0.0010996222,0
|
| 3208 |
-
0,0.5806185,1
|
| 3209 |
-
0,0.0030489217,0
|
| 3210 |
-
0,0.0009842621,0
|
| 3211 |
-
0,0.00096107053,0
|
| 3212 |
-
1,0.68338144,1
|
| 3213 |
-
0,0.0009416806,0
|
| 3214 |
-
0,0.001178365,0
|
| 3215 |
-
0,0.0016215707,0
|
| 3216 |
-
0,0.41308418,0
|
| 3217 |
-
0,0.37985086,0
|
| 3218 |
-
0,0.0010058667,0
|
| 3219 |
-
1,0.27062395,0
|
| 3220 |
-
0,0.0013959251,0
|
| 3221 |
-
1,0.6292908,1
|
| 3222 |
-
0,0.0013573767,0
|
| 3223 |
-
0,0.0028528608,0
|
| 3224 |
-
1,0.7515089,1
|
| 3225 |
-
1,0.58315617,1
|
| 3226 |
-
1,0.29764894,0
|
| 3227 |
-
0,0.0010242697,0
|
| 3228 |
-
0,0.66869485,1
|
| 3229 |
-
1,0.34468064,0
|
| 3230 |
-
0,0.0009077599,0
|
| 3231 |
-
0,0.0012983766,0
|
| 3232 |
-
1,0.6648164,1
|
| 3233 |
-
0,0.00096312474,0
|
| 3234 |
-
0,0.22477408,0
|
| 3235 |
-
0,0.0010165646,0
|
| 3236 |
-
0,0.0012150696,0
|
| 3237 |
-
0,0.001145921,0
|
| 3238 |
-
0,0.0008946433,0
|
| 3239 |
-
0,0.0009847601,0
|
| 3240 |
-
0,0.0019780556,0
|
| 3241 |
-
0,0.03994404,0
|
| 3242 |
-
0,0.39725897,0
|
| 3243 |
-
0,0.0011280741,0
|
| 3244 |
-
0,0.14372922,0
|
| 3245 |
-
0,0.00097693,0
|
| 3246 |
-
0,0.49378365,0
|
| 3247 |
-
0,0.0008469307,0
|
| 3248 |
-
1,0.23917605,0
|
| 3249 |
-
0,0.110021144,0
|
| 3250 |
-
0,0.00080322206,0
|
| 3251 |
-
0,0.0008781373,0
|
| 3252 |
-
0,0.002445047,0
|
| 3253 |
-
0,0.00089050585,0
|
| 3254 |
-
0,0.0012529906,0
|
| 3255 |
-
0,0.0012119152,0
|
| 3256 |
-
0,0.0016275514,0
|
| 3257 |
-
1,0.6807688,1
|
| 3258 |
-
0,0.0013215955,0
|
| 3259 |
-
0,0.78445035,1
|
| 3260 |
-
0,0.0008880186,0
|
| 3261 |
-
0,0.001284776,0
|
| 3262 |
-
0,0.0017074918,0
|
| 3263 |
-
0,0.0009681374,0
|
| 3264 |
-
0,0.6296154,1
|
| 3265 |
-
0,0.001095009,0
|
| 3266 |
-
0,0.0015562188,0
|
| 3267 |
-
1,0.6771481,1
|
| 3268 |
-
1,0.18929906,0
|
| 3269 |
-
0,0.0008950543,0
|
| 3270 |
-
0,0.0008272061,0
|
| 3271 |
-
0,0.0017716148,0
|
| 3272 |
-
1,0.76254493,1
|
| 3273 |
-
0,0.0009578705,0
|
| 3274 |
-
1,0.8615882,1
|
| 3275 |
-
0,0.00092509435,0
|
| 3276 |
-
0,0.0009841388,0
|
| 3277 |
-
0,0.0008676494,0
|
| 3278 |
-
0,0.0014497808,0
|
| 3279 |
-
0,0.0010180256,0
|
| 3280 |
-
0,0.0020955966,0
|
| 3281 |
-
0,0.0016175646,0
|
| 3282 |
-
0,0.0008392496,0
|
| 3283 |
-
1,0.6555392,1
|
| 3284 |
-
0,0.00096220616,0
|
| 3285 |
-
0,0.0012978838,0
|
| 3286 |
-
0,0.0020624246,0
|
| 3287 |
-
0,0.32421842,0
|
| 3288 |
-
0,0.0012323145,0
|
| 3289 |
-
1,0.15064883,0
|
| 3290 |
-
1,0.2249341,0
|
| 3291 |
-
1,0.6194747,1
|
| 3292 |
-
1,0.40883264,0
|
| 3293 |
-
0,0.00088793685,0
|
| 3294 |
-
0,0.5703692,1
|
| 3295 |
-
0,0.00080029253,0
|
| 3296 |
-
1,0.533915,1
|
| 3297 |
-
0,0.0018578873,0
|
| 3298 |
-
0,0.0009043615,0
|
| 3299 |
-
0,0.61236954,1
|
| 3300 |
-
1,0.25842515,0
|
| 3301 |
-
1,0.6920323,1
|
| 3302 |
-
0,0.0009899612,0
|
| 3303 |
-
1,0.21387953,0
|
| 3304 |
-
0,0.0023066083,0
|
| 3305 |
-
1,0.6479849,1
|
| 3306 |
-
0,0.0011913775,0
|
| 3307 |
-
0,0.0010240928,0
|
| 3308 |
-
0,0.00094176497,0
|
| 3309 |
-
0,0.2934143,0
|
| 3310 |
-
1,0.32174513,0
|
| 3311 |
-
0,0.00091399293,0
|
| 3312 |
-
0,0.3778847,0
|
| 3313 |
-
0,0.0010437181,0
|
| 3314 |
-
0,0.00093027006,0
|
| 3315 |
-
1,0.2693971,0
|
| 3316 |
-
1,0.695859,1
|
| 3317 |
-
0,0.0009045273,0
|
| 3318 |
-
0,0.21045512,0
|
| 3319 |
-
0,0.0008732916,0
|
| 3320 |
-
1,0.615637,1
|
| 3321 |
-
0,0.001324741,0
|
| 3322 |
-
0,0.00093508593,0
|
| 3323 |
-
0,0.00074671215,0
|
| 3324 |
-
0,0.33646423,0
|
| 3325 |
-
0,0.001223517,0
|
| 3326 |
-
0,0.34300792,0
|
| 3327 |
-
0,0.0029657178,0
|
| 3328 |
-
0,0.0010389285,0
|
| 3329 |
-
0,0.00087984215,0
|
| 3330 |
-
0,0.0009522997,0
|
| 3331 |
-
0,0.0009270687,0
|
| 3332 |
-
0,0.0010787433,0
|
| 3333 |
-
0,0.001082428,0
|
| 3334 |
-
0,0.0010898644,0
|
| 3335 |
-
0,0.0015669015,0
|
| 3336 |
-
0,0.50207824,1
|
| 3337 |
-
1,0.69201577,1
|
| 3338 |
-
0,0.0027307745,0
|
| 3339 |
-
1,0.65933526,1
|
| 3340 |
-
1,0.33638108,0
|
| 3341 |
-
0,0.0009279095,0
|
| 3342 |
-
1,0.6811094,1
|
| 3343 |
-
1,0.6222203,1
|
| 3344 |
-
0,0.00087204383,0
|
| 3345 |
-
0,0.001260481,0
|
| 3346 |
-
0,0.221174,0
|
| 3347 |
-
0,0.000790207,0
|
| 3348 |
-
0,0.000836068,0
|
| 3349 |
-
0,0.34802088,0
|
| 3350 |
-
0,0.0011530793,0
|
| 3351 |
-
0,0.0008934312,0
|
| 3352 |
-
0,0.000901762,0
|
| 3353 |
-
0,0.0068717427,0
|
| 3354 |
-
0,0.0012540269,0
|
| 3355 |
-
0,0.0008931409,0
|
| 3356 |
-
0,0.24012849,0
|
| 3357 |
-
0,0.00096857373,0
|
| 3358 |
-
0,0.0009792737,0
|
| 3359 |
-
1,0.73453856,1
|
| 3360 |
-
1,0.71891254,1
|
| 3361 |
-
0,0.0012205226,0
|
| 3362 |
-
1,0.5815703,1
|
| 3363 |
-
0,0.0017512442,0
|
| 3364 |
-
0,0.0008828233,0
|
| 3365 |
-
0,0.0039220187,0
|
| 3366 |
-
0,0.0013681627,0
|
| 3367 |
-
0,0.85258,1
|
| 3368 |
-
0,0.0010186047,0
|
| 3369 |
-
1,0.8146731,1
|
| 3370 |
-
0,0.05123307,0
|
| 3371 |
-
1,0.6553712,1
|
| 3372 |
-
0,0.0016004561,0
|
| 3373 |
-
0,0.12838942,0
|
| 3374 |
-
0,0.0009933491,0
|
| 3375 |
-
0,0.0009402028,0
|
| 3376 |
-
0,0.000918758,0
|
| 3377 |
-
1,0.8279392,1
|
| 3378 |
-
1,0.35954273,0
|
| 3379 |
-
1,0.76767814,1
|
| 3380 |
-
1,0.48393753,0
|
| 3381 |
-
1,0.53692836,1
|
| 3382 |
-
0,0.0008827944,0
|
| 3383 |
-
0,0.0009095459,0
|
| 3384 |
-
1,0.6520005,1
|
| 3385 |
-
0,0.0010726445,0
|
| 3386 |
-
0,0.0013205588,0
|
| 3387 |
-
0,0.0012526136,0
|
| 3388 |
-
1,0.518195,1
|
| 3389 |
-
0,0.48449913,0
|
| 3390 |
-
0,0.4973757,0
|
| 3391 |
-
0,0.0008689781,0
|
| 3392 |
-
1,0.6955714,1
|
| 3393 |
-
0,0.0027496573,0
|
| 3394 |
-
0,0.70566237,1
|
| 3395 |
-
0,0.0008412664,0
|
| 3396 |
-
0,0.0013987075,0
|
| 3397 |
-
1,0.13971777,0
|
| 3398 |
-
0,0.0012735971,0
|
| 3399 |
-
0,0.42210418,0
|
| 3400 |
-
1,0.64054,1
|
| 3401 |
-
0,0.0008632639,0
|
| 3402 |
-
0,0.0008357673,0
|
| 3403 |
-
0,0.0010104679,0
|
| 3404 |
-
0,0.0011264505,0
|
| 3405 |
-
0,0.0008627759,0
|
| 3406 |
-
1,0.63967973,1
|
| 3407 |
-
0,0.0015104539,0
|
| 3408 |
-
0,0.7630441,1
|
| 3409 |
-
0,0.25452375,0
|
| 3410 |
-
0,0.0008738256,0
|
| 3411 |
-
0,0.0010032041,0
|
| 3412 |
-
0,0.5034312,1
|
| 3413 |
-
0,0.00092107285,0
|
| 3414 |
-
0,0.0013620523,0
|
| 3415 |
-
1,0.91997075,1
|
| 3416 |
-
1,0.52912873,1
|
| 3417 |
-
0,0.0024464617,0
|
| 3418 |
-
0,0.72202426,1
|
| 3419 |
-
0,0.12549123,0
|
| 3420 |
-
1,0.3554386,0
|
| 3421 |
-
1,0.68960494,1
|
| 3422 |
-
0,0.0011416401,0
|
| 3423 |
-
0,0.062529385,0
|
| 3424 |
-
0,0.00089203403,0
|
| 3425 |
-
0,0.16887772,0
|
| 3426 |
-
0,0.00095955486,0
|
| 3427 |
-
1,0.350305,0
|
| 3428 |
-
0,0.12311296,0
|
| 3429 |
-
0,0.0024356688,0
|
| 3430 |
-
0,0.5414857,1
|
| 3431 |
-
0,0.0008605533,0
|
| 3432 |
-
1,0.7618979,1
|
| 3433 |
-
0,0.0009157459,0
|
| 3434 |
-
0,0.0010139028,0
|
| 3435 |
-
0,0.33930495,0
|
| 3436 |
-
0,0.0007996137,0
|
| 3437 |
-
0,0.0010549367,0
|
| 3438 |
-
0,0.003367107,0
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ef9845cf64c142ff16fc915402953a1383e36ecb1c76b6174fae75c0dec59cd4
|
| 3 |
+
size 54904
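
The three `+` lines are a standard Git LFS pointer (spec v1): `oid` is the SHA-256 digest of the actual file content and `size` is its length in bytes (54,904 here). The real CSV moves into LFS storage and is materialized locally with `git lfs pull`. Below is a minimal sketch of reading such a pointer; the parsing helper is my own illustration and not a utility from this repo.

```python
def parse_lfs_pointer(path: str) -> dict:
    """Parse a Git LFS v1 pointer file (a small text file of 'key value' lines)."""
    fields = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                key, _, value = line.partition(" ")
                fields[key] = value
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],       # https://git-lfs.github.com/spec/v1
        "hash_algorithm": algo,             # "sha256"
        "digest": digest,                   # hex digest of the real file content
        "size_bytes": int(fields["size"]),  # 54904 for the pointer shown above
    }

print(parse_lfs_pointer("metrics/nonfouling/val_predictions_binary.csv"))
```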
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer/.ipynb_checkpoints/my_tokenizers-checkpoint.py
ADDED
|
@@ -0,0 +1,398 @@

(398 lines — Jupyter checkpoint copy, byte-identical to tokenizer/my_tokenizers.py below; not repeated here.)
|
tokenizer/__pycache__/my_tokenizers.cpython-310.pyc
ADDED
|
Binary file (15.5 kB).
|
|
|
tokenizer/my_tokenizers.py
ADDED
|
@@ -0,0 +1,398 @@

import collections
import logging
import os
import re
from typing import List, Optional

from transformers import PreTrainedTokenizer
from SmilesPE.tokenizer import SPE_Tokenizer

logger = logging.getLogger(__name__)  # was referenced below but never defined

# The original referenced VOCAB_FILES_NAMES without defining it; this default
# (the usual transformers convention) is an assumption.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}


def load_vocab(vocab_file):
    """Load a vocabulary file (one token per line) into an ordered dict."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        vocab[token.rstrip("\n")] = index
    return vocab


class Atomwise_Tokenizer(object):
    """Run atom-level SMILES tokenization."""

    def __init__(self):
        """Construct an atom-level tokenizer."""
        # Extends the standard atom-level pattern so that short parenthesized
        # groups such as "(=O)" are captured as single tokens.
        self.regex_pattern = (
            r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p"
            r"|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        )
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Basic tokenization of a SMILES string."""
        return self.regex.findall(text)


class _SMILESVocabMixin:
    """Vocabulary and special-token handling shared by both tokenizers.

    The two classes below originally duplicated these methods verbatim; they
    are factored out here with behavior unchanged.
    """

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _convert_token_to_id(self, token):
        """Convert a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Convert an index (integer) to a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Convert a sequence of tokens to a single string."""
        return " ".join(tokens).replace(" ##", "").strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs by adding special tokens, in the BERT format:
        ``[CLS] X [SEP]`` for a single sequence, ``[CLS] A [SEP] B [SEP]``
        for a pair.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a mask with 1 for special tokens and 0 for sequence tokens."""
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return [1 if x in (self.sep_token_id, self.cls_token_id) else 0
                    for x in token_ids_0]
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a BERT sequence-pair mask: 0 for the first sequence (including
        [CLS] and its [SEP]) and 1 for the second; only the 0 portion is
        returned when token_ids_1 is None.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the vocabulary (one token per line, id = line number) to a
        directory or file path and return the saved file path(s).
        """
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


class SMILES_SPE_Tokenizer(_SMILESVocabMixin, PreTrainedTokenizer):
    r"""
    SMILES tokenizer based on SMILES Pair Encoding
    (https://github.com/XinhaoLi74/SmilesPE). Inherits from
    :class:`~transformers.PreTrainedTokenizer`; users should refer to the
    superclass for the full method documentation.

    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding merges.
        unk_token, sep_token, pad_token, cls_token, mask_token (:obj:`string`, `optional`):
            Special tokens, defaulting to the usual BERT values "[UNK]",
            "[SEP]", "[PAD]", "[CLS]" and "[MASK]".
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        # The vocab must exist before super().__init__(), which may consult
        # get_vocab() while registering the special tokens.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            (ids, tok) for tok, ids in self.vocab.items())
        # SPE_Tokenizer parses the merge rules eagerly, so the handle can be
        # closed right away (the original left it open for the object's lifetime).
        with open(spe_file, "r", encoding="utf-8") as spe_vocab:
            self.spe_tokenizer = SPE_Tokenizer(spe_vocab)

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    def _tokenize(self, text):
        return self.spe_tokenizer.tokenize(text).split(' ')

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        return self.convert_tokens_to_string(tokens)


class SMILES_Atomwise_Tokenizer(_SMILESVocabMixin, PreTrainedTokenizer):
    r"""
    Atom-level SMILES tokenizer with the same vocabulary and special-token
    handling as :class:`SMILES_SPE_Tokenizer`, but using the regex-based
    :class:`Atomwise_Tokenizer` instead of SPE merges.

    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        unk_token, sep_token, pad_token, cls_token, mask_token (:obj:`string`, `optional`):
            Special tokens, defaulting to the usual BERT values.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        # As above, initialize the vocab before super().__init__(); the
        # original called super() first, which breaks on recent transformers.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            (ids, tok) for tok, ids in self.vocab.items())
        self.tokenizer = Atomwise_Tokenizer()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    def _tokenize(self, text):
        return self.tokenizer.tokenize(text)
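
A minimal usage sketch for the two classes above. It assumes the vocabulary and merge files added in this commit (tokenizer/new_vocab.txt and tokenizer/new_splits.txt) are the files the constructors expect, and that the script runs with tokenizer/ on PYTHONPATH; the example SMILES is arbitrary and the commented outputs are illustrative, not verified.

from my_tokenizers import SMILES_SPE_Tokenizer, SMILES_Atomwise_Tokenizer

spe_tok = SMILES_SPE_Tokenizer(
    vocab_file="tokenizer/new_vocab.txt",
    spe_file="tokenizer/new_splits.txt",
)
atom_tok = SMILES_Atomwise_Tokenizer(vocab_file="tokenizer/new_vocab.txt")

smiles = "CC(=O)Oc1ccccc1C(=O)O"   # aspirin
print(atom_tok.tokenize(smiles))   # atom-level tokens; short groups like "(=O)" kept whole
print(spe_tok.tokenize(smiles))    # coarser tokens after the SPE merges

enc = spe_tok(smiles)              # standard transformers call: ids, attention mask, etc.
print(enc["input_ids"])            # [CLS] ... [SEP] ids via build_inputs_with_special_tokens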
|
tokenizer/new_splits.txt
ADDED
|
@@ -0,0 +1,159 @@

c 1
c 2
c 3
c 4
c 5
c 6
c 7
c 8
c 9
( c1
( c2
c1 )
c2 )
n 1
n 2
n 3
n 4
n 5
n 6
n 7
n 8
n 9
( n1
( n2
n1 )
n2 )
O 1
O 2
O 3
O 4
O 5
O 6
O 7
O 8
O 9
( O1
( O2
O2 )
O2 )
= O
= C
= c
= N
= n
=C C
=C N
=C c
=c c
=N C
=N c
=n C
=n c
# N
# C
#N C
#C C
#C N
#N N
( C
C )
( O
O )
( N
N )
Br c
( =O
(=O )
C (=O)
C =O
C =N
C #N
C #C
C C
CC C
CC N
CC O
CC S
CC c
CC n
C N
CN C
CN c
C O
CO C
CO N
CO c
C S
CS C
CS S
CS c
C c
Cl c
C n
F c
N C
NC C
NC c
N N
N O
N c
N n
O C
OC C
OC O
OC c
O N
O O
O c
S C
SC C
SC c
S S
S c
c c
cc c
cc n
cc o
cc s
cc cc
c n
cn c
cn n
c o
co c
c s
cs c
cs n
n c
nc c
nc n
nc o
nc s
n n
nn c
nn n
n o
no c
no n
n s
ns c
ns n
o c
oc c
o n
s c
sc c
sc n
s n
N P
P N
C P
P C
N S
S N
C S
S C
S P
P S
C I
|
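Each line above is one SPE merge rule (a pair of adjacent symbols to fuse), listed in priority order, and is consumed together with new_vocab.txt by SMILES_SPE_Tokenizer. A toy sketch of how such merge tables are typically applied; this greedy loop is only an illustration, not the repo's actual implementation (that lives in tokenizer/my_tokenizers.py):

# Toy SPE-style merging: repeatedly fuse the best-ranked adjacent pair.
merges = [("c", "1"), ("(", "c1"), ("c", "c"), ("cc", "c")]  # tiny subset of new_splits.txt
rank = {pair: i for i, pair in enumerate(merges)}

def spe_merge(symbols):
    symbols = list(symbols)
    while True:
        # find the adjacent pair with the lowest (highest-priority) rank
        pairs = [(rank.get((a, b)), i) for i, (a, b) in enumerate(zip(symbols, symbols[1:]))]
        pairs = [(r, i) for r, i in pairs if r is not None]
        if not pairs:
            return symbols
        _, i = min(pairs)
        symbols[i:i + 2] = [symbols[i] + symbols[i + 1]]

print(spe_merge(["c", "1", "c", "c", "c", "c", "c", "1"]))
# ['c1', 'cc', 'cc', 'c1'] -- atoms fuse into the multi-character tokens seen in new_vocab.txt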
tokenizer/new_vocab.txt
ADDED
@@ -0,0 +1,586 @@
[PAD]
[UNK]
[CLS]
[SEP]
[MASK]
#
%
(
)
+
-
/
0
1
2
3
4
5
6
7
8
9
=
@
A
B
Br
Brc
C
CC
CCC
CCN
CCO
CCS
CCc
CCn
CN
CNC
CNc
CO
COC
CON
COc
CS
CSC
CSS
CSc
Cc
Cl
Clc
Cn
F
Fc
H
I
K
L
M
N
NC
NCC
NCc
NN
NO
Nc
Nn
O
OC
OCC
OCO
OCc
ON
OO
Oc
P
R
S
SC
SCC
SCc
SS
Sc
T
X
Z
[
\\
(/
]
a
b
c
cc
ccc
ccn
cco
ccs
cn
cnc
cnn
co
coc
cs
csc
csn
e
g
i
l
n
nc
ncc
ncn
nco
ncs
nn
nnc
nnn
no
noc
non
ns
nsc
nsn
o
oc
occ
on
p
r
s
sc
scc
scn
sn
t
c1
c2
c3
c4
c5
c6
c7
c8
c9
n1
n2
n3
n4
n5
n6
n7
n8
n9
O1
O2
O3
O4
O5
O6
O7
O8
O9
(c1
(c2
c1)
c2)
(n1
(n2
n1)
n2)
(O1
(O2
O2)
=O
=C
=c
=N
=n
=CC
=CN
=Cc
=cc
=NC
=Nc
=nC
=nc
#C
#CC
#CN
#N
#NC
#NN
(C
C)
(O
O)
(N
N)
NP
PN
CP
PC
NS
SN
SP
PS
C(=O)
(/Br)
(/C#N)
(/C)
(/C=N)
(/C=O)
(/CBr)
(/CC)
(/CCC)
(/CCF)
(/CCN)
(/CCO)
(/CCl)
(/CI)
(/CN)
(/CO)
(/CS)
(/Cl)
(/F)
(/I)
(/N)
(/NC)
(/NCC)
(/NO)
(/O)
(/OC)
(/OCC)
(/S)
(/SC)
(=C)
(=C/C)
(=C/F)
(=C/I)
(=C/N)
(=C/O)
(=CBr)
(=CC)
(=CCF)
(=CCN)
(=CCO)
(=CCl)
(=CF)
(=CI)
(=CN)
(=CO)
(=C\\C)
(=C\\F)
(=C\\I)
(=C\\N)
(=C\\O)
(=N)
(=N/C)
(=N/N)
(=N/O)
(=NBr)
(=NC)
(=NCC)
(=NCl)
(=NN)
(=NO)
(=NOC)
(=N\\C)
(=N\\N)
(=N\\O)
(=O)
(=S)
(B)
(Br)
(C#C)
(C#CC)
(C#CI)
(C#CO)
(C#N)
(C#SN)
(C)
(C=C)
(C=CF)
(C=CI)
(C=N)
(C=NN)
(C=NO)
(C=O)
(C=S)
(CBr)
(CC#C)
(CC#N)
(CC)
(CC=C)
(CC=O)
(CCBr)
(CCC)
(CCCC)
(CCCF)
(CCCI)
(CCCN)
(CCCO)
(CCCS)
(CCCl)
(CCF)
(CCI)
(CCN)
(CCNC)
(CCNN)
(CCNO)
(CCO)
(CCOC)
(CCON)
(CCS)
(CCSC)
(CCl)
(CF)
(CI)
(CN)
(CN=O)
(CNC)
(CNCC)
(CNCO)
(CNN)
(CNNC)
(CNO)
(CNOC)
(CO)
(COC)
(COCC)
(COCI)
(COCN)
(COCO)
(COF)
(CON)
(COO)
(CS)
(CSC)
(CSCC)
(CSCF)
(CSO)
(Cl)
(F)
(I)
(N)
(N=N)
(N=NO)
(N=O)
(N=S)
(NBr)
(NC#N)
(NC)
(NC=N)
(NC=O)
(NC=S)
(NCBr)
(NCC)
(NCCC)
(NCCF)
(NCCN)
(NCCO)
(NCCS)
(NCCl)
(NCNC)
(NCO)
(NCS)
(NCl)
(NN)
(NN=O)
(NNC)
(NO)
(NOC)
(O)
(OC#N)
(OC)
(OC=C)
(OC=O)
(OC=S)
(OCBr)
(OCC)
(OCCC)
(OCCF)
(OCCI)
(OCCN)
(OCCO)
(OCCS)
(OCCl)
(OCF)
(OCI)
(OCO)
(OCOC)
(OCON)
(OCSC)
(OCl)
(OI)
(ON)
(OO)
(OOC)
(OOCC)
(OOSN)
(OSC)
(P)
(S)
(SC#N)
(SC)
(SCC)
(SCCC)
(SCCF)
(SCCN)
(SCCO)
(SCCS)
(SCCl)
(SCF)
(SCN)
(SCOC)
(SCSC)
(SCl)
(SI)
(SN)
(SN=O)
(SO)
(SOC)
(SOOO)
(SS)
(SSC)
(SSCC)
([At])
([O-])
([O])
([S-])
(\\Br)
(\\C#N)
(\\C)
(\\C=N)
(\\C=O)
(\\CBr)
(\\CC)
(\\CCC)
(\\CCO)
(\\CCl)
(\\CF)
(\\CN)
(\\CNC)
(\\CO)
(\\COC)
(\\Cl)
(\\F)
(\\I)
(\\N)
(\\NC)
(\\NCC)
(\\NN)
(\\NO)
(\\NOC)
(\\O)
(\\OC)
(\\OCC)
(\\ON)
(\\S)
(\\SC)
(\\SCC)
[Ag+]
[Ag-4]
[Ag]
[Al-3]
[Al]
[As+]
[AsH3]
[AsH]
[As]
[At]
[B-]
[B@-]
[B@@-]
[BH-]
[BH2-]
[BH3-]
[B]
[Ba]
[Br+2]
[BrH]
[Br]
[C+]
[C-]
[C@@H]
[C@@]
[C@H]
[C@]
[CH-]
[CH2]
[CH3]
[CH]
[C]
[CaH2]
[Ca]
[Cl+2]
[Cl+3]
[Cl+]
[Cs]
[FH]
[F]
[H]
[He]
[I+2]
[I+3]
[I+]
[IH]
[I]
[K]
[Kr]
[Li+]
[LiH]
[MgH2]
[Mg]
[N+]
[N-]
[N@+]
[N@@+]
[N@@]
[N@]
[NH+]
[NH-]
[NH2+]
[NH3]
[NH]
[N]
[Na]
[O+]
[O-]
[OH+]
[OH2]
[OH]
[O]
[P+]
[P@+]
[P@@+]
[P@@]
[P@]
[PH2]
[PH]
[P]
[Ra]
[Rb]
[S+]
[S-]
[S@+]
[S@@+]
[S@@]
[S@]
[SH+]
[SH2]
[SH]
[S]
[Se+]
[Se-2]
[SeH2]
[SeH]
[Se]
[Si@]
[SiH2]
[SiH]
[Si]
[SrH2]
[TeH]
[Te]
[Xe]
[Zn+2]
[Zn-2]
[Zn]
[b-]
[c+]
[c-]
[cH-]
[cH]
[c]
[n+]
[n-]
[nH]
[n]
[o+]
[s+]
[se+]
[se]
[te+]
[te]
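new_vocab.txt (token per line, line number = token id) and new_splits.txt are loaded as a pair; binding_affinity_split.py below does exactly this. A minimal usage sketch, assuming the repo root is on sys.path and using relative paths for readability:

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

# vocab: token -> id (line number); splits: SPE merge rules in priority order
tok = SMILES_SPE_Tokenizer("tokenizer/new_vocab.txt", "tokenizer/new_splits.txt")

# used like a Hugging Face tokenizer elsewhere in this commit
enc = tok(["CC(=O)Nc1ccccc1"], return_tensors="pt", padding=True,
          truncation=True, max_length=768)
print(enc["input_ids"].shape, enc["attention_mask"].shape)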
training_classifiers/.gitignore
ADDED
File without changes
training_classifiers/.ipynb_checkpoints/binding_affinity_iptm-checkpoint.py
ADDED
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""
extract_iptm_affinity_csv_all.py

Writes:
  - out_dir/wt_iptm_affinity_all.csv
  - out_dir/smiles_iptm_affinity_all.csv

Also prints:
  - N
  - Spearman rho (affinity vs iptm)
  - Pearson r (affinity vs iptm)
"""

import json
from pathlib import Path

import numpy as np
import pandas as pd


def corr_stats(df: pd.DataFrame, x: str, y: str):
    # pandas handles NaNs if we already dropped them; still be safe
    xx = pd.to_numeric(df[x], errors="coerce")
    yy = pd.to_numeric(df[y], errors="coerce")
    m = xx.notna() & yy.notna()
    xx = xx[m]
    yy = yy[m]
    n = int(m.sum())

    # Pearson r
    pearson_r = float(xx.corr(yy, method="pearson")) if n > 1 else float("nan")
    # Spearman rho
    spearman_rho = float(xx.corr(yy, method="spearman")) if n > 1 else float("nan")

    return {"n": n, "pearson_r": pearson_r, "spearman_rho": spearman_rho}


def clean_one(
    in_csv: Path,
    out_csv: Path,
    iptm_col: str,
    affinity_col: str = "affinity",
    keep_cols=(),
):
    df = pd.read_csv(in_csv)

    # affinity + iptm must exist
    need = [affinity_col, iptm_col]
    missing = [c for c in need if c not in df.columns]
    if missing:
        raise ValueError(f"{in_csv} missing columns: {missing}. Found: {list(df.columns)}")

    # coerce numeric
    df[affinity_col] = pd.to_numeric(df[affinity_col], errors="coerce")
    df[iptm_col] = pd.to_numeric(df[iptm_col], errors="coerce")

    # drop NaNs in either
    df = df.dropna(subset=[affinity_col, iptm_col]).reset_index(drop=True)

    # output cols (standardize names)
    out = pd.DataFrame({
        "affinity": df[affinity_col].astype(float),
        "iptm": df[iptm_col].astype(float),
    })

    # keep split if present (handy for coloring later, but not used for corr)
    if "split" in df.columns:
        out.insert(0, "split", df["split"].astype(str))

    # optional extras for labeling/debug
    for c in keep_cols:
        if c in df.columns:
            out[c] = df[c]

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    out.to_csv(out_csv, index=False)

    stats = corr_stats(out, "iptm", "affinity")
    print(f"[write] {out_csv}")
    print(f"  N={stats['n']} | Pearson r={stats['pearson_r']:.4f} | Spearman rho={stats['spearman_rho']:.4f}")

    # also save stats json next to csv
    stats_path = out_csv.with_suffix(".stats.json")
    with open(stats_path, "w") as f:
        json.dump(
            {
                "input_csv": str(in_csv),
                "output_csv": str(out_csv),
                "iptm_col": iptm_col,
                "affinity_col": affinity_col,
                **stats,
            },
            f,
            indent=2,
        )


def main():
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--wt_meta_csv", type=str, required=True)
    ap.add_argument("--smiles_meta_csv", type=str, required=True)
    ap.add_argument("--out_dir", type=str, required=True)

    ap.add_argument("--wt_iptm_col", type=str, default="wt_iptm_score")
    ap.add_argument("--smiles_iptm_col", type=str, default="smiles_iptm_score")
    ap.add_argument("--affinity_col", type=str, default="affinity")
    args = ap.parse_args()

    out_dir = Path(args.out_dir)

    clean_one(
        Path(args.wt_meta_csv),
        out_dir / "wt_iptm_affinity_all.csv",
        iptm_col=args.wt_iptm_col,
        affinity_col=args.affinity_col,
        keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES"),
    )

    clean_one(
        Path(args.smiles_meta_csv),
        out_dir / "smiles_iptm_affinity_all.csv",
        iptm_col=args.smiles_iptm_col,
        affinity_col=args.affinity_col,
        keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES", "smiles_sequence"),
    )

    print(f"\n[DONE] CSVs + stats JSONs in: {out_dir}")


if __name__ == "__main__":
    main()
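A quick way to sanity-check corr_stats in isolation (toy data; assumes the script's directory is importable so the module name resolves):

import pandas as pd
from binding_affinity_iptm import corr_stats  # hypothetical import path; add the folder to sys.path if needed

toy = pd.DataFrame({
    "iptm": [0.31, 0.55, 0.72, 0.80, None],   # the NaN row is masked out inside corr_stats
    "affinity": [5.1, 6.9, 8.4, 9.2, 7.0],
})
print(corr_stats(toy, "iptm", "affinity"))
# e.g. {'n': 4, 'pearson_r': 0.99..., 'spearman_rho': 1.0} -- monotone toy data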
training_classifiers/.ipynb_checkpoints/binding_affinity_split-checkpoint.py
ADDED
@@ -0,0 +1,847 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import sys
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
# tqdm is optional; we’ll disable it by default in notebooks
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
|
| 16 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 17 |
+
|
| 18 |
+
from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
|
| 19 |
+
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
|
| 20 |
+
|
| 21 |
+
# -------------------------
|
| 22 |
+
# Config
|
| 23 |
+
# -------------------------
|
| 24 |
+
CSV_PATH = Path("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/c-binding_with_openfold_scores.csv")
|
| 25 |
+
|
| 26 |
+
OUT_ROOT = Path(
|
| 27 |
+
"/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_data_cleaned/binding_affinity"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# WT (seq) embedding model
|
| 31 |
+
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
|
| 32 |
+
WT_MAX_LEN = 1022
|
| 33 |
+
WT_BATCH = 32
|
| 34 |
+
|
| 35 |
+
# SMILES embedding model + tokenizer
|
| 36 |
+
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
|
| 37 |
+
TOKENIZER_VOCAB = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_vocab.txt"
|
| 38 |
+
TOKENIZER_SPLITS = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_splits.txt"
|
| 39 |
+
SMI_MAX_LEN = 768
|
| 40 |
+
SMI_BATCH = 128
|
| 41 |
+
|
| 42 |
+
# Split config
|
| 43 |
+
TRAIN_FRAC = 0.80
|
| 44 |
+
RANDOM_SEED = 1986
|
| 45 |
+
AFFINITY_Q_BINS = 30
|
| 46 |
+
|
| 47 |
+
# Columns expected in CSV
|
| 48 |
+
COL_SEQ1 = "seq1"
|
| 49 |
+
COL_SEQ2 = "seq2"
|
| 50 |
+
COL_AFF = "affinity"
|
| 51 |
+
COL_F2S = "Fasta2SMILES"
|
| 52 |
+
COL_REACT = "REACT_SMILES"
|
| 53 |
+
COL_WT_IPTM = "wt_iptm_score"
|
| 54 |
+
COL_SMI_IPTM = "smiles_iptm_score"
|
| 55 |
+
|
| 56 |
+
# Device
|
| 57 |
+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
|
| 59 |
+
# -------------------------
|
| 60 |
+
# Quiet / notebook-safe output controls
|
| 61 |
+
# -------------------------
|
| 62 |
+
QUIET = True # suppress most prints
|
| 63 |
+
USE_TQDM = False # disable tqdm bars (recommended in Jupyter to avoid crashing)
|
| 64 |
+
LOG_FILE = None # optionally: OUT_ROOT / "build.log"
|
| 65 |
+
|
| 66 |
+
def log(msg: str):
|
| 67 |
+
if LOG_FILE is not None:
|
| 68 |
+
Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
|
| 69 |
+
with open(LOG_FILE, "a") as f:
|
| 70 |
+
f.write(msg.rstrip() + "\n")
|
| 71 |
+
if not QUIET:
|
| 72 |
+
print(msg)
|
| 73 |
+
|
| 74 |
+
def pbar(it, **kwargs):
|
| 75 |
+
return tqdm(it, **kwargs) if USE_TQDM else it
|
| 76 |
+
|
| 77 |
+
@contextmanager
|
| 78 |
+
def section(title: str):
|
| 79 |
+
log(f"\n=== {title} ===")
|
| 80 |
+
yield
|
| 81 |
+
log(f"=== done: {title} ===")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# -------------------------
|
| 85 |
+
# Helpers
|
| 86 |
+
# -------------------------
|
| 87 |
+
def has_uaa(seq: str) -> bool:
|
| 88 |
+
return "X" in str(seq).upper()
|
| 89 |
+
|
| 90 |
+
def affinity_to_class(a: float) -> str:
|
| 91 |
+
# High: >= 9 ; Moderate: [7, 9) ; Low: < 7
|
| 92 |
+
if a >= 9.0:
|
| 93 |
+
return "High"
|
| 94 |
+
elif a >= 7.0:
|
| 95 |
+
return "Moderate"
|
| 96 |
+
else:
|
| 97 |
+
return "Low"
|
| 98 |
+
|
| 99 |
+
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
|
| 100 |
+
df = df.copy()
|
| 101 |
+
|
| 102 |
+
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
|
| 103 |
+
df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
|
| 104 |
+
|
| 105 |
+
df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
|
| 106 |
+
|
| 107 |
+
try:
|
| 108 |
+
df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
|
| 109 |
+
strat_col = "aff_bin"
|
| 110 |
+
except Exception:
|
| 111 |
+
df["aff_bin"] = df["affinity_class"]
|
| 112 |
+
strat_col = "aff_bin"
|
| 113 |
+
|
| 114 |
+
rng = np.random.RandomState(RANDOM_SEED)
|
| 115 |
+
|
| 116 |
+
df["split"] = None
|
| 117 |
+
for _, g in df.groupby(strat_col, observed=True):
|
| 118 |
+
idx = g.index.to_numpy()
|
| 119 |
+
rng.shuffle(idx)
|
| 120 |
+
n_train = int(math.floor(len(idx) * TRAIN_FRAC))
|
| 121 |
+
df.loc[idx[:n_train], "split"] = "train"
|
| 122 |
+
df.loc[idx[n_train:], "split"] = "val"
|
| 123 |
+
|
| 124 |
+
df["split"] = df["split"].fillna("train")
|
| 125 |
+
return df
|
| 126 |
+
|
| 127 |
+
def _summ(x):
|
| 128 |
+
x = np.asarray(x, dtype=float)
|
| 129 |
+
x = x[~np.isnan(x)]
|
| 130 |
+
if len(x) == 0:
|
| 131 |
+
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
|
| 132 |
+
return {
|
| 133 |
+
"n": int(len(x)),
|
| 134 |
+
"mean": float(np.mean(x)),
|
| 135 |
+
"std": float(np.std(x)),
|
| 136 |
+
"p50": float(np.quantile(x, 0.50)),
|
| 137 |
+
"p95": float(np.quantile(x, 0.95)),
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
def _len_stats(seqs):
|
| 141 |
+
lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
|
| 142 |
+
if len(lens) == 0:
|
| 143 |
+
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
|
| 144 |
+
return {
|
| 145 |
+
"n": int(len(lens)),
|
| 146 |
+
"mean": float(lens.mean()),
|
| 147 |
+
"std": float(lens.std()),
|
| 148 |
+
"p50": float(np.quantile(lens, 0.50)),
|
| 149 |
+
"p95": float(np.quantile(lens, 0.95)),
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
def verify_split_before_embedding(
|
| 153 |
+
df2: pd.DataFrame,
|
| 154 |
+
affinity_col: str,
|
| 155 |
+
split_col: str,
|
| 156 |
+
seq_col: str,
|
| 157 |
+
iptm_col: str,
|
| 158 |
+
aff_class_col: str = "affinity_class",
|
| 159 |
+
aff_bins: int = 30,
|
| 160 |
+
save_report_prefix: str | None = None,
|
| 161 |
+
verbose: bool = False,
|
| 162 |
+
):
|
| 163 |
+
"""
|
| 164 |
+
Notebook-safe: by default prints only ONE line via `log()`.
|
| 165 |
+
Optionally writes CSV reports (stats + class proportions).
|
| 166 |
+
"""
|
| 167 |
+
df2 = df2.copy()
|
| 168 |
+
df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
|
| 169 |
+
df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
|
| 170 |
+
|
| 171 |
+
assert split_col in df2.columns, f"Missing split col: {split_col}"
|
| 172 |
+
assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
|
| 173 |
+
assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
|
| 177 |
+
except Exception:
|
| 178 |
+
df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
|
| 179 |
+
|
| 180 |
+
tr = df2[df2[split_col] == "train"].reset_index(drop=True)
|
| 181 |
+
va = df2[df2[split_col] == "val"].reset_index(drop=True)
|
| 182 |
+
|
| 183 |
+
tr_aff = _summ(tr[affinity_col].to_numpy())
|
| 184 |
+
va_aff = _summ(va[affinity_col].to_numpy())
|
| 185 |
+
tr_len = _len_stats(tr[seq_col].tolist())
|
| 186 |
+
va_len = _len_stats(va[seq_col].tolist())
|
| 187 |
+
|
| 188 |
+
# bin drift
|
| 189 |
+
bin_ct = (
|
| 190 |
+
df2.groupby([split_col, "_aff_bin_dbg"])
|
| 191 |
+
.size()
|
| 192 |
+
.groupby(level=0)
|
| 193 |
+
.apply(lambda s: s / s.sum())
|
| 194 |
+
)
|
| 195 |
+
tr_bins = bin_ct.loc["train"]
|
| 196 |
+
va_bins = bin_ct.loc["val"]
|
| 197 |
+
all_bins = tr_bins.index.union(va_bins.index)
|
| 198 |
+
tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
|
| 199 |
+
va_bins = va_bins.reindex(all_bins, fill_value=0.0)
|
| 200 |
+
max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
|
| 201 |
+
|
| 202 |
+
msg = (
|
| 203 |
+
f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
|
| 204 |
+
f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
|
| 205 |
+
f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
|
| 206 |
+
f"max_bin_diff={max_bin_diff:.4f}"
|
| 207 |
+
)
|
| 208 |
+
log(msg)
|
| 209 |
+
|
| 210 |
+
if verbose and (not QUIET):
|
| 211 |
+
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
|
| 212 |
+
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
|
| 213 |
+
print("\n[verbose] affinity_class counts:\n", class_ct)
|
| 214 |
+
print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
|
| 215 |
+
|
| 216 |
+
if save_report_prefix is not None:
|
| 217 |
+
out = Path(save_report_prefix)
|
| 218 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 219 |
+
|
| 220 |
+
stats_df = pd.DataFrame([
|
| 221 |
+
{"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
|
| 222 |
+
{"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
|
| 223 |
+
])
|
| 224 |
+
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
|
| 225 |
+
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
|
| 226 |
+
|
| 227 |
+
stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
|
| 228 |
+
class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# -------------------------
|
| 232 |
+
# WT pooled (ESM2)
|
| 233 |
+
# -------------------------
|
| 234 |
+
@torch.no_grad()
|
| 235 |
+
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
|
| 236 |
+
embs = []
|
| 237 |
+
for i in pbar(range(0, len(seqs), batch_size)):
|
| 238 |
+
batch = seqs[i:i + batch_size]
|
| 239 |
+
inputs = tokenizer(
|
| 240 |
+
batch,
|
| 241 |
+
padding=True,
|
| 242 |
+
truncation=True,
|
| 243 |
+
max_length=max_length,
|
| 244 |
+
return_tensors="pt",
|
| 245 |
+
)
|
| 246 |
+
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 247 |
+
out = model(**inputs)
|
| 248 |
+
h = out.last_hidden_state # (B, L, H)
|
| 249 |
+
|
| 250 |
+
attn = inputs["attention_mask"].unsqueeze(-1) # (B, L, 1)
|
| 251 |
+
summed = (h * attn).sum(dim=1) # (B, H)
|
| 252 |
+
denom = attn.sum(dim=1).clamp(min=1e-9) # (B, 1)
|
| 253 |
+
pooled = (summed / denom).detach().cpu().numpy()
|
| 254 |
+
embs.append(pooled)
|
| 255 |
+
|
| 256 |
+
return np.vstack(embs)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# -------------------------
|
| 260 |
+
# WT unpooled (ESM2)
|
| 261 |
+
# -------------------------
|
| 262 |
+
@torch.no_grad()
|
| 263 |
+
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
|
| 264 |
+
tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
|
| 265 |
+
tok = {k: v.to(DEVICE) for k, v in tok.items()}
|
| 266 |
+
out = model(**tok)
|
| 267 |
+
h = out.last_hidden_state[0] # (L, H)
|
| 268 |
+
attn = tok["attention_mask"][0].bool() # (L,)
|
| 269 |
+
ids = tok["input_ids"][0]
|
| 270 |
+
|
| 271 |
+
keep = attn.clone()
|
| 272 |
+
if cls_id is not None:
|
| 273 |
+
keep &= (ids != cls_id)
|
| 274 |
+
if eos_id is not None:
|
| 275 |
+
keep &= (ids != eos_id)
|
| 276 |
+
|
| 277 |
+
return h[keep].detach().cpu().to(torch.float16).numpy()
|
| 278 |
+
|
| 279 |
+
def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
|
| 280 |
+
"""
|
| 281 |
+
Expects df_split to have:
|
| 282 |
+
- target_sequence (seq1)
|
| 283 |
+
- sequence (binder seq2; WT binder)
|
| 284 |
+
- label, affinity_class, COL_AFF, COL_WT_IPTM
|
| 285 |
+
Saves a dataset where each row contains BOTH:
|
| 286 |
+
- target_embedding (Lt,H), target_attention_mask, target_length
|
| 287 |
+
- binder_embedding (Lb,H), binder_attention_mask, binder_length
|
| 288 |
+
"""
|
| 289 |
+
cls_id = tokenizer.cls_token_id
|
| 290 |
+
eos_id = tokenizer.eos_token_id
|
| 291 |
+
H = model.config.hidden_size
|
| 292 |
+
|
| 293 |
+
features = Features({
|
| 294 |
+
"target_sequence": Value("string"),
|
| 295 |
+
"sequence": Value("string"),
|
| 296 |
+
"label": Value("float32"),
|
| 297 |
+
"affinity": Value("float32"),
|
| 298 |
+
"affinity_class": Value("string"),
|
| 299 |
+
|
| 300 |
+
"target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
|
| 301 |
+
"target_attention_mask": HFSequence(Value("int8")),
|
| 302 |
+
"target_length": Value("int64"),
|
| 303 |
+
|
| 304 |
+
"binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
|
| 305 |
+
"binder_attention_mask": HFSequence(Value("int8")),
|
| 306 |
+
"binder_length": Value("int64"),
|
| 307 |
+
|
| 308 |
+
COL_WT_IPTM: Value("float32"),
|
| 309 |
+
COL_AFF: Value("float32"),
|
| 310 |
+
})
|
| 311 |
+
|
| 312 |
+
def gen_rows(df: pd.DataFrame):
|
| 313 |
+
for r in pbar(df.itertuples(index=False), total=len(df)):
|
| 314 |
+
tgt = str(getattr(r, "target_sequence")).strip()
|
| 315 |
+
bnd = str(getattr(r, "sequence")).strip()
|
| 316 |
+
|
| 317 |
+
y = float(getattr(r, "label"))
|
| 318 |
+
aff = float(getattr(r, COL_AFF))
|
| 319 |
+
acls = str(getattr(r, "affinity_class"))
|
| 320 |
+
|
| 321 |
+
iptm = getattr(r, COL_WT_IPTM)
|
| 322 |
+
iptm = float(iptm) if pd.notna(iptm) else np.nan
|
| 323 |
+
|
| 324 |
+
# token embeddings for target + binder (both ESM)
|
| 325 |
+
t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lt,H)
|
| 326 |
+
b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lb,H)
|
| 327 |
+
|
| 328 |
+
t_list = t_emb.tolist()
|
| 329 |
+
b_list = b_emb.tolist()
|
| 330 |
+
Lt = len(t_list)
|
| 331 |
+
Lb = len(b_list)
|
| 332 |
+
|
| 333 |
+
yield {
|
| 334 |
+
"target_sequence": tgt,
|
| 335 |
+
"sequence": bnd,
|
| 336 |
+
"label": np.float32(y),
|
| 337 |
+
"affinity": np.float32(aff),
|
| 338 |
+
"affinity_class": acls,
|
| 339 |
+
|
| 340 |
+
"target_embedding": t_list,
|
| 341 |
+
"target_attention_mask": [1] * Lt,
|
| 342 |
+
"target_length": int(Lt),
|
| 343 |
+
|
| 344 |
+
"binder_embedding": b_list,
|
| 345 |
+
"binder_attention_mask": [1] * Lb,
|
| 346 |
+
"binder_length": int(Lb),
|
| 347 |
+
|
| 348 |
+
COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
|
| 349 |
+
COL_AFF: np.float32(aff),
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 353 |
+
ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
|
| 354 |
+
ds.save_to_disk(str(out_dir), max_shard_size="1GB")
|
| 355 |
+
return ds
|
| 356 |
+
|
| 357 |
+
def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
|
| 358 |
+
smi_tok, smi_roformer):
|
| 359 |
+
"""
|
| 360 |
+
df_split must have:
|
| 361 |
+
- target_sequence (seq1)
|
| 362 |
+
- sequence (binder smiles string)
|
| 363 |
+
- label, affinity_class, COL_AFF, COL_SMI_IPTM
|
| 364 |
+
Saves rows with:
|
| 365 |
+
target_embedding (Lt,Ht) from ESM
|
| 366 |
+
binder_embedding (Lb,Hb) from PeptideCLM
|
| 367 |
+
"""
|
| 368 |
+
cls_id = wt_tokenizer.cls_token_id
|
| 369 |
+
eos_id = wt_tokenizer.eos_token_id
|
| 370 |
+
Ht = wt_model_unpooled.config.hidden_size
|
| 371 |
+
|
| 372 |
+
# Infer Hb from one forward pass? easiest: run one mini batch outside in main if you want.
|
| 373 |
+
# Here: we’ll infer from model config if available.
|
| 374 |
+
Hb = getattr(smi_roformer.config, "hidden_size", None)
|
| 375 |
+
if Hb is None:
|
| 376 |
+
Hb = getattr(smi_roformer.config, "dim", None)
|
| 377 |
+
if Hb is None:
|
| 378 |
+
raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
|
| 379 |
+
|
| 380 |
+
features = Features({
|
| 381 |
+
"target_sequence": Value("string"),
|
| 382 |
+
"sequence": Value("string"),
|
| 383 |
+
"label": Value("float32"),
|
| 384 |
+
"affinity": Value("float32"),
|
| 385 |
+
"affinity_class": Value("string"),
|
| 386 |
+
|
| 387 |
+
"target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
|
| 388 |
+
"target_attention_mask": HFSequence(Value("int8")),
|
| 389 |
+
"target_length": Value("int64"),
|
| 390 |
+
|
| 391 |
+
"binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
|
| 392 |
+
"binder_attention_mask": HFSequence(Value("int8")),
|
| 393 |
+
"binder_length": Value("int64"),
|
| 394 |
+
|
| 395 |
+
COL_SMI_IPTM: Value("float32"),
|
| 396 |
+
COL_AFF: Value("float32"),
|
| 397 |
+
})
|
| 398 |
+
|
| 399 |
+
def gen_rows(df: pd.DataFrame):
|
| 400 |
+
for r in pbar(df.itertuples(index=False), total=len(df)):
|
| 401 |
+
tgt = str(getattr(r, "target_sequence")).strip()
|
| 402 |
+
bnd = str(getattr(r, "sequence")).strip()
|
| 403 |
+
|
| 404 |
+
y = float(getattr(r, "label"))
|
| 405 |
+
aff = float(getattr(r, COL_AFF))
|
| 406 |
+
acls = str(getattr(r, "affinity_class"))
|
| 407 |
+
|
| 408 |
+
iptm = getattr(r, COL_SMI_IPTM)
|
| 409 |
+
iptm = float(iptm) if pd.notna(iptm) else np.nan
|
| 410 |
+
|
| 411 |
+
# target token embeddings (ESM)
|
| 412 |
+
t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
|
| 413 |
+
t_list = t_emb.tolist()
|
| 414 |
+
Lt = len(t_list)
|
| 415 |
+
|
| 416 |
+
# binder token embeddings (PeptideCLM) — single-item batch
|
| 417 |
+
_, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
|
| 418 |
+
[bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
|
| 419 |
+
)
|
| 420 |
+
b_emb = tok_list[0] # np.float16 (Lb, Hb)
|
| 421 |
+
b_list = b_emb.tolist()
|
| 422 |
+
Lb = int(lengths[0])
|
| 423 |
+
b_mask = mask_list[0].astype(np.int8).tolist()
|
| 424 |
+
|
| 425 |
+
yield {
|
| 426 |
+
"target_sequence": tgt,
|
| 427 |
+
"sequence": bnd,
|
| 428 |
+
"label": np.float32(y),
|
| 429 |
+
"affinity": np.float32(aff),
|
| 430 |
+
"affinity_class": acls,
|
| 431 |
+
|
| 432 |
+
"target_embedding": t_list,
|
| 433 |
+
"target_attention_mask": [1] * Lt,
|
| 434 |
+
"target_length": int(Lt),
|
| 435 |
+
|
| 436 |
+
"binder_embedding": b_list,
|
| 437 |
+
"binder_attention_mask": [int(x) for x in b_mask],
|
| 438 |
+
"binder_length": int(Lb),
|
| 439 |
+
|
| 440 |
+
COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
|
| 441 |
+
COL_AFF: np.float32(aff),
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 445 |
+
ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
|
| 446 |
+
ds.save_to_disk(str(out_dir), max_shard_size="1GB")
|
| 447 |
+
return ds
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
# -------------------------
|
| 451 |
+
# SMILES pooled + unpooled (PeptideCLM)
|
| 452 |
+
# -------------------------
|
| 453 |
+
def get_special_ids(tokenizer_obj):
|
| 454 |
+
cand = [
|
| 455 |
+
getattr(tokenizer_obj, "pad_token_id", None),
|
| 456 |
+
getattr(tokenizer_obj, "cls_token_id", None),
|
| 457 |
+
getattr(tokenizer_obj, "sep_token_id", None),
|
| 458 |
+
getattr(tokenizer_obj, "bos_token_id", None),
|
| 459 |
+
getattr(tokenizer_obj, "eos_token_id", None),
|
| 460 |
+
getattr(tokenizer_obj, "mask_token_id", None),
|
| 461 |
+
]
|
| 462 |
+
return sorted({x for x in cand if x is not None})
|
| 463 |
+
|
| 464 |
+
@torch.no_grad()
|
| 465 |
+
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
|
| 466 |
+
tok = tokenizer_obj(
|
| 467 |
+
batch_sequences,
|
| 468 |
+
return_tensors="pt",
|
| 469 |
+
padding=True,
|
| 470 |
+
truncation=True,
|
| 471 |
+
max_length=max_length,
|
| 472 |
+
)
|
| 473 |
+
input_ids = tok["input_ids"].to(DEVICE)
|
| 474 |
+
attention_mask = tok["attention_mask"].to(DEVICE)
|
| 475 |
+
|
| 476 |
+
outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
|
| 477 |
+
last_hidden = outputs.last_hidden_state # (B, L, H)
|
| 478 |
+
|
| 479 |
+
special_ids = get_special_ids(tokenizer_obj)
|
| 480 |
+
valid = attention_mask.bool()
|
| 481 |
+
if len(special_ids) > 0:
|
| 482 |
+
sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
|
| 483 |
+
if hasattr(torch, "isin"):
|
| 484 |
+
valid = valid & (~torch.isin(input_ids, sid))
|
| 485 |
+
else:
|
| 486 |
+
m = torch.zeros_like(valid)
|
| 487 |
+
for s in special_ids:
|
| 488 |
+
m |= (input_ids == s)
|
| 489 |
+
valid = valid & (~m)
|
| 490 |
+
|
| 491 |
+
valid_f = valid.unsqueeze(-1).float()
|
| 492 |
+
summed = torch.sum(last_hidden * valid_f, dim=1)
|
| 493 |
+
denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
|
| 494 |
+
pooled = (summed / denom).detach().cpu().numpy()
|
| 495 |
+
|
| 496 |
+
token_emb_list, mask_list, lengths = [], [], []
|
| 497 |
+
for b in range(last_hidden.shape[0]):
|
| 498 |
+
emb = last_hidden[b, valid[b]] # (Li, H)
|
| 499 |
+
token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
|
| 500 |
+
li = emb.shape[0]
|
| 501 |
+
lengths.append(int(li))
|
| 502 |
+
mask_list.append(np.ones((li,), dtype=np.int8))
|
| 503 |
+
|
| 504 |
+
return pooled, token_emb_list, mask_list, lengths
|
| 505 |
+
|
| 506 |
+
def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
|
| 507 |
+
pooled_all = []
|
| 508 |
+
token_emb_all = []
|
| 509 |
+
mask_all = []
|
| 510 |
+
lengths_all = []
|
| 511 |
+
|
| 512 |
+
for i in pbar(range(0, len(seqs), batch_size)):
|
| 513 |
+
batch = seqs[i:i + batch_size]
|
| 514 |
+
pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
|
| 515 |
+
batch, tokenizer_obj, model_roformer, max_length
|
| 516 |
+
)
|
| 517 |
+
pooled_all.append(pooled)
|
| 518 |
+
token_emb_all.extend(tok_list)
|
| 519 |
+
mask_all.extend(m_list)
|
| 520 |
+
lengths_all.extend(lens)
|
| 521 |
+
|
| 522 |
+
return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
|
| 523 |
+
|
| 524 |
+
# -------------------------
|
| 525 |
+
# Target embedding cache (NO extra ESM runs)
|
| 526 |
+
# We will compute target pooled embeddings ONCE from WT view, then reuse for SMILES.
|
| 527 |
+
# -------------------------
|
| 528 |
+
def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
|
| 529 |
+
wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
|
| 530 |
+
wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
|
| 531 |
+
|
| 532 |
+
# compute target pooled embeddings once
|
| 533 |
+
tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
|
| 534 |
+
tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
|
| 535 |
+
|
| 536 |
+
wt_train_tgt_emb = wt_pooled_embeddings(
|
| 537 |
+
tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
|
| 538 |
+
)
|
| 539 |
+
wt_val_tgt_emb = wt_pooled_embeddings(
|
| 540 |
+
tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
# build dict: target_sequence -> embedding (float32 array)
|
| 544 |
+
# if duplicates exist, last wins; you can add checks if needed
|
| 545 |
+
train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
|
| 546 |
+
val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
|
| 547 |
+
return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
|
| 548 |
+
# -------------------------
|
| 549 |
+
# Main
|
| 550 |
+
# -------------------------
|
| 551 |
+
def main():
|
| 552 |
+
log(f"[INFO] DEVICE: {DEVICE}")
|
| 553 |
+
OUT_ROOT.mkdir(parents=True, exist_ok=True)
|
| 554 |
+
|
| 555 |
+
# 1) Load
|
| 556 |
+
with section("load csv + dedup"):
|
| 557 |
+
df = pd.read_csv(CSV_PATH)
|
| 558 |
+
for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
|
| 559 |
+
if c in df.columns:
|
| 560 |
+
df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
|
| 561 |
+
|
| 562 |
+
# Dedup on the full identity tuple you want
|
| 563 |
+
DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
|
| 564 |
+
df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
|
| 565 |
+
|
| 566 |
+
print("Rows after dedup on", DEDUP_COLS, ":", len(df))
|
| 567 |
+
|
| 568 |
+
need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
|
| 569 |
+
missing = [c for c in need if c not in df.columns]
|
| 570 |
+
if missing:
|
| 571 |
+
raise ValueError(f"Missing required columns: {missing}")
|
| 572 |
+
|
| 573 |
+
# numeric affinity for both branches
|
| 574 |
+
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
|
| 575 |
+
|
| 576 |
+
# 2) Build WT subset + SMILES subset separately (NO global dropping)
|
| 577 |
+
with section("prepare wt/smiles subsets"):
|
| 578 |
+
# WT: requires a canonical peptide sequence (no X) + affinity
|
| 579 |
+
df_wt = df.copy()
|
| 580 |
+
df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
|
| 581 |
+
df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
|
| 582 |
+
df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
|
| 583 |
+
df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)
|
| 584 |
+
|
| 585 |
+
# SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES)
|
| 586 |
+
df_smi = df.copy()
|
| 587 |
+
df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
|
| 588 |
+
df_smi = df_smi[
|
| 589 |
+
pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
|
| 590 |
+
].reset_index(drop=True) # empty iptm means sth wrong with their smiles sequenc
|
| 591 |
+
|
| 592 |
+
is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
|
| 593 |
+
df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
|
| 594 |
+
df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
|
| 595 |
+
df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
|
| 596 |
+
df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
|
| 597 |
+
|
| 598 |
+
log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")
|
| 599 |
+
|
| 600 |
+
# 3) Split separately (different sizes and memberships are expected)
|
| 601 |
+
with section("split wt and smiles separately"):
|
| 602 |
+
df_wt2 = make_distribution_matched_split(df_wt)
|
| 603 |
+
df_smi2 = make_distribution_matched_split(df_smi)
|
| 604 |
+
|
| 605 |
+
# save split tables
|
| 606 |
+
wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
|
| 607 |
+
smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
|
| 608 |
+
df_wt2.to_csv(wt_split_csv, index=False)
|
| 609 |
+
df_smi2.to_csv(smi_split_csv, index=False)
|
| 610 |
+
log(f"Saved WT split meta: {wt_split_csv}")
|
| 611 |
+
log(f"Saved SMILES split meta: {smi_split_csv}")
|
| 612 |
+
|
| 613 |
+
# lightweight double-check (one-line)
|
| 614 |
+
verify_split_before_embedding(
|
| 615 |
+
df2=df_wt2,
|
| 616 |
+
affinity_col=COL_AFF,
|
| 617 |
+
split_col="split",
|
| 618 |
+
seq_col="wt_sequence",
|
| 619 |
+
iptm_col=COL_WT_IPTM,
|
| 620 |
+
aff_class_col="affinity_class",
|
| 621 |
+
aff_bins=AFFINITY_Q_BINS,
|
| 622 |
+
save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
|
| 623 |
+
verbose=False,
|
| 624 |
+
)
|
| 625 |
+
verify_split_before_embedding(
|
| 626 |
+
df2=df_smi2,
|
| 627 |
+
affinity_col=COL_AFF,
|
| 628 |
+
split_col="split",
|
| 629 |
+
seq_col="smiles_sequence",
|
| 630 |
+
iptm_col=COL_SMI_IPTM,
|
| 631 |
+
aff_class_col="affinity_class",
|
| 632 |
+
aff_bins=AFFINITY_Q_BINS,
|
| 633 |
+
save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
|
| 634 |
+
verbose=False,
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
# Prepare split views
|
| 638 |
+
def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
|
| 639 |
+
out = df_in.copy()
|
| 640 |
+
out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() # <-- NEW
|
| 641 |
+
out["sequence"] = out[binder_seq_col].astype(str).str.strip() # binder
|
| 642 |
+
out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
|
| 643 |
+
out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
|
| 644 |
+
out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
|
| 645 |
+
out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
|
| 646 |
+
return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
|
| 647 |
+
|
| 648 |
+
wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
|
| 649 |
+
smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
|
| 650 |
+
|
| 651 |
+
# -------------------------
|
| 652 |
+
# Split views
|
| 653 |
+
# -------------------------
|
| 654 |
+
wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
|
| 655 |
+
wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
|
| 656 |
+
smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
|
| 657 |
+
smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
# =========================
|
| 661 |
+
# TARGET pooled embeddings (ESM) — SEPARATE per branch
|
| 662 |
+
# =========================
|
| 663 |
+
with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
|
| 664 |
+
wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
|
| 665 |
+
wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
|
| 666 |
+
|
| 667 |
+
# ---- WT targets ----
|
| 668 |
+
wt_train_tgt_emb = wt_pooled_embeddings(
|
| 669 |
+
wt_train["target_sequence"].astype(str).str.strip().tolist(),
|
| 670 |
+
wt_tok, wt_esm,
|
| 671 |
+
batch_size=WT_BATCH,
|
| 672 |
+
max_length=WT_MAX_LEN,
|
| 673 |
+
).astype(np.float32)
|
| 674 |
+
|
| 675 |
+
wt_val_tgt_emb = wt_pooled_embeddings(
|
| 676 |
+
wt_val["target_sequence"].astype(str).str.strip().tolist(),
|
| 677 |
+
wt_tok, wt_esm,
|
| 678 |
+
batch_size=WT_BATCH,
|
| 679 |
+
max_length=WT_MAX_LEN,
|
| 680 |
+
).astype(np.float32)
|
| 681 |
+
|
| 682 |
+
# ---- SMILES targets (independent; may include UAA-only targets) ----
|
| 683 |
+
smi_train_tgt_emb = wt_pooled_embeddings(
|
| 684 |
+
smi_train["target_sequence"].astype(str).str.strip().tolist(),
|
| 685 |
+
wt_tok, wt_esm,
|
| 686 |
+
batch_size=WT_BATCH,
|
| 687 |
+
max_length=WT_MAX_LEN,
|
| 688 |
+
).astype(np.float32)
|
| 689 |
+
|
| 690 |
+
smi_val_tgt_emb = wt_pooled_embeddings(
|
| 691 |
+
smi_val["target_sequence"].astype(str).str.strip().tolist(),
|
| 692 |
+
wt_tok, wt_esm,
|
| 693 |
+
batch_size=WT_BATCH,
|
| 694 |
+
max_length=WT_MAX_LEN,
|
| 695 |
+
).astype(np.float32)
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
    # =========================
    # WT pooled binder embeddings (binder = WT peptide)
    # =========================
    with section("WT pooled binder embeddings + save"):
        wt_train_emb = wt_pooled_embeddings(
            wt_train["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_emb = wt_pooled_embeddings(
            wt_val["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_train_ds = Dataset.from_dict({
            "target_sequence": wt_train["target_sequence"].tolist(),
            "sequence": wt_train["sequence"].tolist(),
            "label": wt_train["label"].astype(float).tolist(),
            "target_embedding": wt_train_tgt_emb,
            "embedding": wt_train_emb,
            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_train["affinity_class"].tolist(),
        })

        wt_val_ds = Dataset.from_dict({
            "target_sequence": wt_val["target_sequence"].tolist(),
            "sequence": wt_val["sequence"].tolist(),
            "label": wt_val["label"].astype(float).tolist(),
            "target_embedding": wt_val_tgt_emb,
            "embedding": wt_val_emb,
            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_val["affinity_class"].tolist(),
        })

        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
        log(f"Saved WT pooled -> {wt_pooled_out}")

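    # A minimal read-back sketch (assuming the same `datasets` package; the
    # path is exactly the one written above, the column list follows the
    # Dataset.from_dict insertion order, and the iptm/affinity names come from
    # the COL_* constants):
    #   from datasets import load_from_disk
    #   dd = load_from_disk(str(OUT_ROOT / "pair_wt_wt_pooled"))
    #   dd["val"].column_names
    #   # ['target_sequence', 'sequence', 'label', 'target_embedding',
    #   #  'embedding', <COL_WT_IPTM>, <COL_AFF>, 'affinity_class']
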
    # =========================
    # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
    # =========================
    with section("SMILES pooled binder embeddings + save"):
        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        smi_roformer = (
            AutoModelForMaskedLM
            .from_pretrained(SMI_MODEL_NAME)
            .roformer
            .to(DEVICE)
            .eval()
        )

        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_train["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_val["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_train_ds = Dataset.from_dict({
            "target_sequence": smi_train["target_sequence"].tolist(),
            "sequence": smi_train["sequence"].tolist(),
            "label": smi_train["label"].astype(float).tolist(),
            "target_embedding": smi_train_tgt_emb,
            "embedding": smi_train_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_train["affinity_class"].tolist(),
        })

        smi_val_ds = Dataset.from_dict({
            "target_sequence": smi_val["target_sequence"].tolist(),
            "sequence": smi_val["sequence"].tolist(),
            "label": smi_val["label"].astype(float).tolist(),
            "target_embedding": smi_val_tgt_emb,
            "embedding": smi_val_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_val["affinity_class"].tolist(),
        })

        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
        log(f"Saved SMILES pooled -> {smi_pooled_out}")

    # =========================
    # WT unpooled paired (ESM target + ESM binder) + save
    # =========================
    with section("WT unpooled paired embeddings + save"):
        wt_tok_unpooled = wt_tok  # reuse tokenizer
        wt_esm_unpooled = wt_esm  # reuse model

        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
        wt_unpooled_dd = DatasetDict({
            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
                                               wt_tok_unpooled, wt_esm_unpooled),
            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
                                             wt_tok_unpooled, wt_esm_unpooled),
        })
        # (Optional) also save as DatasetDict root if you want a single load_from_disk path:
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
        log(f"Saved WT unpooled -> {wt_unpooled_out}")

    # =========================
    # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
    # =========================
    with section("SMILES unpooled paired embeddings + save"):
        # reuse the already-loaded smi_tok/smi_roformer from the pooled section if still in scope;
        # otherwise re-init here:
        # smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        # smi_roformer = AutoModelForMaskedLM.from_pretrained(SMI_MODEL_NAME).roformer.to(DEVICE).eval()

        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
        smi_unpooled_dd = DatasetDict({
            "train": build_smiles_unpooled_paired_dataset(
                smi_train, smi_unpooled_out / "train",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
            "val": build_smiles_unpooled_paired_dataset(
                smi_val, smi_unpooled_out / "val",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
        })
        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")

    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")


if __name__ == "__main__":
    main()
training_classifiers/.ipynb_checkpoints/binding_training-checkpoint.py
ADDED
@@ -0,0 +1,414 @@
import os, json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import optuna
from datasets import load_from_disk, DatasetDict
from scipy.stats import spearmanr
from lightning.pytorch import seed_everything
seed_everything(1986)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)


# -----------------------------
# Affinity class thresholds (final spec)
# High >= 9 ; Moderate 7-9 ; Low < 7
# 0=High, 1=Moderate, 2=Low
# -----------------------------
def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
    high = y >= 9.0
    low = y < 7.0
    mid = ~(high | low)
    cls = torch.zeros_like(y, dtype=torch.long)
    cls[mid] = 1
    cls[low] = 2
    return cls

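# Worked example of the mapping above (values follow directly from the
# thresholds; the tensor literal is illustrative only):
#   affinity_to_class_tensor(torch.tensor([9.5, 8.0, 6.2]))
#   -> tensor([0, 1, 2])   # High, Moderate, Low
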
# -----------------------------
# Load paired DatasetDict
# -----------------------------
def load_split_paired(path: str):
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    if "train" not in dd or "val" not in dd:
        raise ValueError(f"DatasetDict missing train/val at {path}")
    return dd["train"], dd["val"]


# -----------------------------
# Collate: pooled paired
# -----------------------------
def collate_pair_pooled(batch):
    Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)  # (B, Ht)
    Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)  # (B, Hb)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    return Pt, Pb, y


# -----------------------------
# Collate: unpooled paired
# -----------------------------
def collate_pair_unpooled(batch):
    B = len(batch)
    Ht = len(batch[0]["target_embedding"][0])
    Hb = len(batch[0]["binder_embedding"][0])
    Lt_max = max(int(x["target_length"]) for x in batch)
    Lb_max = max(int(x["binder_length"]) for x in batch)

    Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
    Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
    Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
    Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        t = torch.tensor(x["target_embedding"], dtype=torch.float32)
        b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
        lt, lb = t.shape[0], b.shape[0]
        Pt[i, :lt] = t
        Pb[i, :lb] = b
        Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
        Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)

    return Pt, Mt, Pb, Mb, y

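# Shape sketch for collate_pair_unpooled (numbers are made up): a batch of 2
# with target lengths [5, 3] and binder lengths [4, 7] yields
#   Pt: (2, 5, Ht), Mt: (2, 5), Pb: (2, 7, Hb), Mb: (2, 7), y: (2,)
# zero-padded past each true length, with the masks marking real tokens.
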
# -----------------------------
# Cross-attention models
# -----------------------------
class CrossAttnPooled(nn.Module):
    """
    pooled vectors -> treat as single-token sequences for cross attention
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def forward(self, t_vec, b_vec):
        # t_vec: (B, Ht), b_vec: (B, Hb)
        t = self.t_proj(t_vec).unsqueeze(0)  # (1, B, H)
        b = self.b_proj(b_vec).unsqueeze(0)  # (1, B, H)

        for L in self.layers:
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"]((t + t_attn).transpose(0, 1)).transpose(0, 1)
            t = L["n2t"]((t + L["fft"](t)).transpose(0, 1)).transpose(0, 1)

            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"]((b + b_attn).transpose(0, 1)).transpose(0, 1)
            b = L["n2b"]((b + L["ffb"](b)).transpose(0, 1)).transpose(0, 1)

        t0 = t[0]
        b0 = b[0]
        z = torch.cat([t0, b0], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)

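# Design note (an observation, not original commentary): with length-1
# "sequences" on both sides, softmax over a single key assigns it weight 1, so
# each cross-attention call reduces to a learned linear transform of the other
# vector added residually. The transposes around LayerNorm do not change the
# result (LayerNorm acts on the last dim); they only mirror the unpooled
# variant below.
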
class CrossAttnUnpooled(nn.Module):
    """
    token sequences with masks; alternating cross attention.
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

    def forward(self, T, Mt, B, Mb):
        # T: (B, Lt, Ht), Mt: (B, Lt) ; B: (B, Lb, Hb), Mb: (B, Lb)
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        kp_t = ~Mt  # key_padding_mask: True = pad
        kp_b = ~Mb

        for L in self.layers:
            # T attends to B
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))

            # B attends to T
            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))

        t_pool = self.masked_mean(T, Mt)
        b_pool = self.masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)

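# masked_mean in plain arithmetic: for a row with mask [1, 1, 0] and token
# vectors x1, x2, x3, the pooled vector is (x1 + x2) / 2 -- padding contributes
# nothing and the denominator counts only real tokens (clamped to >= 1).
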
# -----------------------------
# Train/eval
# -----------------------------
@torch.no_grad()
def eval_spearman_pooled(model, loader):
    model.eval()
    ys, ps = [], []
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        pred, _ = model(t, b)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))


@torch.no_grad()
def eval_spearman_unpooled(model, loader):
    model.eval()
    ys, ps = [], []
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        pred, _ = model(T, Mt, B, Mb)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))


def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    model.train()
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        y_cls = affinity_to_class_tensor(y)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(t, b)
        L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
        L.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()


def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    model.train()
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        y_cls = affinity_to_class_tensor(y)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(T, Mt, B, Mb)
        L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
        L.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()

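# Loss sketch for both epoch loops above:
#   total = MSE(pred, y) + cls_w * CrossEntropy(logits, affinity_to_class_tensor(y))
# i.e. the regression head is trained jointly with a 3-way High/Moderate/Low
# head derived from the same label; cls_w ("cls_weight") is tuned by Optuna.
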
# -----------------------------
# Optuna objective
# -----------------------------
def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.4)
    hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
    n_heads = trial.suggest_categorical("n_heads", [4, 8])
    n_layers = trial.suggest_int("n_layers", 1, 4)
    cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])

    # infer dims from the first row
    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        collate = collate_pair_pooled
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        collate = collate_pair_unpooled
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    best = -1e9
    bad = 0
    patience = 10

    for ep in range(1, 61):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        trial.report(rho, ep)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if rho > best + 1e-6:
            best = rho
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    return float(best)

# -----------------------------
# Run: optuna + refit best
# -----------------------------
def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)

    study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
    best = study.best_trial
    best_params = dict(best.params)

    # refit longer
    lr = float(best_params["lr"])
    wd = float(best_params["weight_decay"])
    dropout = float(best_params["dropout"])
    hidden = int(best_params["hidden_dim"])
    n_heads = int(best_params["n_heads"])
    n_layers = int(best_params["n_layers"])
    cls_w = float(best_params["cls_weight"])
    batch = int(best_params["batch_size"])

    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_pooled
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_unpooled
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_rho = -1e9
    bad = 0
    patience = 20
    best_state = None

    for ep in range(1, 201):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        if rho > best_rho + 1e-6:
            best_rho = rho
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # save
    torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
    with open(out_dir / "best_params.json", "w") as f:
        json.dump(best_params, f, indent=2)

    print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")

if __name__ == "__main__":
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
    ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
    ap.add_argument("--out_dir", type=str, required=True)
    ap.add_argument("--n_trials", type=int, default=50)
    args = ap.parse_args()

    run(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        mode=args.mode,
        n_trials=args.n_trials,
    )
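# Example invocation (a sketch; the paths are placeholders for whatever
# binding_affinity_split.py wrote under its OUT_ROOT):
#   python binding_training.py \
#       --dataset_path /path/to/pair_wt_wt_pooled \
#       --mode pooled \
#       --out_dir ./binding_affinity/wt_wt_pooled \
#       --n_trials 50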
training_classifiers/.ipynb_checkpoints/binding_wt-checkpoint.bash
ADDED
@@ -0,0 +1,31 @@
#!/bin/bash
#SBATCH --job-name=b-data
#SBATCH --partition=dgx-b200
#SBATCH --gpus=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=100G
#SBATCH --time=48:00:00
#SBATCH --output=%x_%j.out

HOME_LOC=/vast/projects/pranam/lab/yz927
SCRIPT_LOC=$HOME_LOC/projects/Classifier_Weight/training_classifiers
DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned
OBJECTIVE='binding_affinity'
WT='smiles'      # wt/smiles
STATUS='pooled'  # pooled/unpooled
DATA_FILE="pair_wt_${WT}_${STATUS}"
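# NOTE: WT, STATUS, and DATA_FILE name one of the four datasets but are not
# passed to binding_affinity_split.py below; that script writes all four
# pair_wt_{wt,smiles}_{pooled,unpooled} datasets in a single run.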
LOG_LOC=$SCRIPT_LOC
DATE=$(date +%m_%d)
SPECIAL_PREFIX="binding_affinity_data_generation"

# Create the log directory if it doesn't exist
mkdir -p "$LOG_LOC"

cd "$SCRIPT_LOC"
source /vast/projects/pranam/lab/shared/miniconda3/etc/profile.d/conda.sh
conda activate /vast/projects/pranam/lab/shared/miniconda3/envs/metal

python -u binding_affinity_split.py > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1

echo "Script completed at $(date)"
conda deactivate
training_classifiers/.ipynb_checkpoints/finetune_boost-checkpoint.py
ADDED
@@ -0,0 +1,508 @@
#!/usr/bin/env python3
# finetune_xgb_halflife_cv_optuna.py

import os
import json
import math
import hashlib
from dataclasses import dataclass
from typing import Dict, Any, Optional, Tuple, List

import numpy as np
import pandas as pd
import optuna

from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import spearmanr

import torch
from transformers import AutoTokenizer, AutoModel

import xgboost as xgb


# -----------------------------
# Repro
# -----------------------------
SEED = 1986
np.random.seed(SEED)
torch.manual_seed(SEED)


# -----------------------------
# Metrics (mirrors the stability script style)
# -----------------------------
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)


def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))
    rho = float(safe_spearmanr(y_true, y_pred))
    return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}

# -----------------------------
# ESM-2 embeddings (cached)
# -----------------------------
@dataclass
class ESMEmbedderConfig:
    model_name: str = "facebook/esm2_t33_650M_UR50D"
    batch_size: int = 8
    max_length: int = 1024  # truncate very long proteins
    fp16: bool = True


class ESM2Embedder:
    """
    Mean-pooled last hidden state (excluding special tokens) -> (H,) per sequence.
    """
    def __init__(self, cfg: ESMEmbedderConfig, device: str = "cuda"):
        self.cfg = cfg
        self.device = device if (device == "cuda" and torch.cuda.is_available()) else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, do_lower_case=False)
        self.model = AutoModel.from_pretrained(cfg.model_name)
        self.model.eval()
        self.model.to(self.device)

        # Turn off gradients
        for p in self.model.parameters():
            p.requires_grad = False

    @torch.inference_mode()
    def embed(self, seqs: List[str]) -> np.ndarray:
        out = []
        bs = self.cfg.batch_size

        use_amp = (self.cfg.fp16 and self.device == "cuda")
        autocast = torch.cuda.amp.autocast if use_amp else torch.cpu.amp.autocast  # safe fallback

        for i in range(0, len(seqs), bs):
            batch = [s.strip().upper() for s in seqs[i:i+bs]]
            toks = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.cfg.max_length,
                add_special_tokens=True,
            )
            toks = {k: v.to(self.device) for k, v in toks.items()}
            attn = toks["attention_mask"]  # (B, L)

            with autocast(enabled=use_amp):
                h = self.model(**toks).last_hidden_state  # (B, L, H)

            # mask out special tokens: first token is <cls>; last non-pad token is usually <eos>
            mask = attn.clone()
            mask[:, 0] = 0
            lengths = attn.sum(dim=1)  # includes special tokens
            # zero out the last real token position per sequence
            eos_pos = (lengths - 1).clamp(min=0)
            mask[torch.arange(mask.size(0), device=mask.device), eos_pos] = 0

            denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)  # (B, 1)
            pooled = (h * mask.unsqueeze(-1)).sum(dim=1) / denom  # (B, H)
            out.append(pooled.float().detach().cpu().numpy())

        return np.concatenate(out, axis=0).astype(np.float32)

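# Pooling sketch: for a tokenized row <cls> A C D <eos> <pad>, the mask above
# zeroes <cls>, <eos>, and padding, so the returned vector is the plain mean of
# the hidden states for A, C, and D only.
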
def dataset_fingerprint(seqs: List[str], y: np.ndarray, extra: str = "") -> str:
    h = hashlib.sha256()
    for s in seqs:
        h.update(s.encode("utf-8"))
        h.update(b"\n")
    h.update(np.asarray(y, dtype=np.float32).tobytes())
    h.update(extra.encode("utf-8"))
    return h.hexdigest()[:16]


def load_or_compute_embeddings(
    df: pd.DataFrame,
    out_dir: str,
    embed_cfg: ESMEmbedderConfig,
    device: str,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    os.makedirs(out_dir, exist_ok=True)

    seqs = df["sequence"].astype(str).tolist()
    y = df["half_life_hours"].astype(float).to_numpy(dtype=np.float32)

    fp = dataset_fingerprint(seqs, y, extra=f"{embed_cfg.model_name}|{embed_cfg.max_length}")
    emb_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.npy")
    meta_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.json")

    if os.path.exists(emb_path) and os.path.exists(meta_path):
        X = np.load(emb_path).astype(np.float32)
        return X, y, np.asarray(seqs)

    embedder = ESM2Embedder(embed_cfg, device=device)
    X = embedder.embed(seqs)  # (N, H)

    np.save(emb_path, X)
    with open(meta_path, "w") as f:
        json.dump(
            {
                "fingerprint": fp,
                "model_name": embed_cfg.model_name,
                "max_length": embed_cfg.max_length,
                "n": len(seqs),
                "dim": int(X.shape[1]),
            },
            f,
            indent=2,
        )
    return X, y, np.asarray(seqs)

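# Cache-key sketch: the fingerprint hashes every sequence, the label vector,
# and "{model_name}|{max_length}", so changing any of the three produces a new
# esm2_embeddings_<fp>.npy instead of silently reusing a stale cache.
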
# -----------------------------
# XGBoost training (supports "finetune" via xgb_model)
# -----------------------------
def train_xgb_reg(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    params: Dict[str, Any],
    base_model_json: Optional[str] = None,
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray, int]:
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    num_boost_round = int(params.pop("num_boost_round"))
    early_stopping_rounds = int(params.pop("early_stopping_rounds"))

    # Important: load a fresh base model each fold (avoid leakage)
    xgb_model = None
    if base_model_json is not None:
        booster0 = xgb.Booster()
        booster0.load_model(base_model_json)
        xgb_model = booster0

    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dval, "val")],
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=False,
        xgb_model=xgb_model,  # <-- "finetune": continue boosting from the base model
    )

    p_train = booster.predict(dtrain)
    p_val = booster.predict(dval)
    best_iter = int(getattr(booster, "best_iteration", num_boost_round - 1))
    return booster, p_train, p_val, best_iter

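# Continuation note: passing xgb_model to xgb.train() warm-starts boosting by
# appending new trees after the base booster's existing ones; the old trees are
# kept as-is, which is why a fresh copy is reloaded for every fold.
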
# -----------------------------
# Optuna objective: 5-fold mean Spearman rho
# -----------------------------
def make_cv_objective(
    X: np.ndarray,
    y: np.ndarray,
    n_splits: int,
    device: str,
    base_model_json: Optional[str],
    target_transform: str,
):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)

    # Optional target transform (sometimes helps with heavy-tailed half-life)
    if target_transform == "log1p":
        y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
    elif target_transform == "none":
        y_used = y.astype(np.float32)
    else:
        raise ValueError(f"Unknown target_transform: {target_transform}")

    def objective(trial: optuna.Trial) -> float:
        # Hyperparameter ranges patterned after the stability script
        params = {
            "objective": "reg:squarederror",
            "eval_metric": "rmse",

            "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
            "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
            "gamma": trial.suggest_float("gamma", 0.0, 10.0),

            "max_depth": trial.suggest_int("max_depth", 2, 12),
            "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 200.0, log=True),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),

            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),

            "tree_method": "hist",
            "device": "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu",
        }
        params["num_boost_round"] = trial.suggest_int("num_boost_round", 30, 1500)
        params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 10, 150)

        fold_metrics = []
        fold_best_iters = []

        for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
            Xtr, ytr = X[tr_idx], y_used[tr_idx]
            Xva, yva = X[va_idx], y_used[va_idx]

            _, _, p_va, best_iter = train_xgb_reg(
                Xtr, ytr, Xva, yva, params.copy(),
                base_model_json=base_model_json,
            )

            m = eval_regression(yva, p_va)
            fold_metrics.append(m)
            fold_best_iters.append(best_iter)

        mean_rho = float(np.mean([m["spearman_rho"] for m in fold_metrics]))
        mean_rmse = float(np.mean([m["rmse"] for m in fold_metrics]))
        mean_mae = float(np.mean([m["mae"] for m in fold_metrics]))
        mean_r2 = float(np.mean([m["r2"] for m in fold_metrics]))
        mean_best_iter = float(np.mean(fold_best_iters))

        trial.set_user_attr("cv_spearman_rho", mean_rho)
        trial.set_user_attr("cv_rmse", mean_rmse)
        trial.set_user_attr("cv_mae", mean_mae)
        trial.set_user_attr("cv_r2", mean_r2)
        trial.set_user_attr("cv_mean_best_iter", mean_best_iter)

        # maximize Spearman rho (same as the stability workflow)
        return mean_rho

    return objective

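# Transform note: with target_transform="log1p", a half-life of 0 h maps to 0.0
# and 19 h to log(20) ~= 3.0, compressing the heavy right tail. Spearman rho is
# rank-based, so for fixed predictions it scores identically against log1p(y)
# and raw y.
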
def refit_and_save(
    X: np.ndarray,
    y: np.ndarray,
    seqs: np.ndarray,
    out_dir: str,
    best_params: Dict[str, Any],
    n_splits: int,
    device: str,
    base_model_json: Optional[str],
    target_transform: str,
):
    os.makedirs(out_dir, exist_ok=True)

    # Transform the target consistently with the CV objective
    if target_transform == "log1p":
        y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
    else:
        y_used = y.astype(np.float32)

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)

    # 1) get OOF preds + average best_iteration
    oof_pred = np.zeros_like(y_used, dtype=np.float32)
    best_iters = []
    fold_rows = []

    for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
        Xtr, ytr = X[tr_idx], y_used[tr_idx]
        Xva, yva = X[va_idx], y_used[va_idx]

        _, _, p_va, best_iter = train_xgb_reg(
            Xtr, ytr, Xva, yva, best_params.copy(),
            base_model_json=base_model_json,
        )
        oof_pred[va_idx] = p_va.astype(np.float32)
        best_iters.append(best_iter)

        m = eval_regression(yva, p_va)
        fold_rows.append({"fold": fold, **m, "best_iter": int(best_iter)})

    fold_df = pd.DataFrame(fold_rows)
    fold_df.to_csv(os.path.join(out_dir, "cv_fold_metrics.csv"), index=False)

    cv_metrics = eval_regression(y_used, oof_pred)
    with open(os.path.join(out_dir, "cv_oof_summary.json"), "w") as f:
        json.dump(cv_metrics, f, indent=2)

    oof_df = pd.DataFrame({
        "sequence": seqs,
        "y_true_used": y_used.astype(float),
        "y_pred_oof": oof_pred.astype(float),
        "residual": (y_used - oof_pred).astype(float),
    })
    oof_df.to_csv(os.path.join(out_dir, "cv_oof_predictions.csv"), index=False)

    mean_best_iter = int(round(float(np.mean(best_iters))))
    final_rounds = max(mean_best_iter + 1, 10)

    # 2) train the final model on ALL data (no early stopping here; use final_rounds,
    #    derived from the average best_iteration above, not the tuned num_boost_round cap)
    dtrain_all = xgb.DMatrix(X, label=y_used)

    xgb_model = None
    if base_model_json is not None:
        booster0 = xgb.Booster()
        booster0.load_model(base_model_json)
        xgb_model = booster0

    final_params = best_params.copy()
    final_params.pop("early_stopping_rounds", None)
    final_params.pop("num_boost_round", None)
    final_params["device"] = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"

    booster = xgb.train(
        params=final_params,
        dtrain=dtrain_all,
        num_boost_round=int(final_rounds),
        evals=[],
        verbose_eval=False,
        xgb_model=xgb_model,
    )

    model_path = os.path.join(out_dir, "best_model_finetuned.json")
    booster.save_model(model_path)

    with open(os.path.join(out_dir, "final_training_notes.json"), "w") as f:
        json.dump(
            {
                "target_transform": target_transform,
                "final_rounds_used": int(final_rounds),
                "cv_oof_metrics_on_used_target": cv_metrics,
                "model_path": model_path,
            },
            f,
            indent=2,
        )

    print("=" * 72)
    print("[Final] CV OOF metrics (on transformed target if enabled):")
    print(json.dumps(cv_metrics, indent=2))
    print(f"[Final] Saved finetuned model -> {model_path}")
    print("=" * 72)

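# OOF sketch: each row of cv_oof_predictions.csv is predicted by the one fold
# model that held that row out, so the summary scores the full dataset without
# any row ever being scored by a model trained on it.
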
def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_path", type=str, default="/scratch/pranamlab/tong/data/halflife/wt_halflife_merged_dedup.csv")
    parser.add_argument("--out_dir", type=str, default="/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_xgb")

    # If provided, we "finetune" by continuing boosting from this model
    parser.add_argument("--base_model_json", type=str, default='/scratch/pranamlab/tong/PeptiVerse/src/stability/xgboost/best_model.json', help="Path to an existing XGBoost .json model to continue training from")

    # ESM embedding config
    parser.add_argument("--esm_model", type=str, default="facebook/esm2_t33_650M_UR50D")
    parser.add_argument("--esm_batch_size", type=int, default=8)
    parser.add_argument("--esm_max_length", type=int, default=1024)
    parser.add_argument("--no_fp16", action="store_true")

    # Training config
    parser.add_argument("--n_trials", type=int, default=200)
    parser.add_argument("--n_splits", type=int, default=5)
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--target_transform", type=str, default="none", choices=["none", "log1p"])

    args = parser.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Load data
    df = pd.read_csv(args.csv_path)
    if "sequence" not in df.columns or "half_life_hours" not in df.columns:
        raise ValueError("CSV must contain columns: sequence, half_life_hours")

    df = df.dropna(subset=["sequence", "half_life_hours"]).copy()
    df["sequence"] = df["sequence"].astype(str).str.strip()
    df = df[df["sequence"].str.len() > 0]
    df = df.drop_duplicates(subset=["sequence"], keep="first").reset_index(drop=True)

    print(f"[Data] N={len(df)} from {args.csv_path}")

    # Embeddings (cached)
    embed_cfg = ESMEmbedderConfig(
        model_name=args.esm_model,
        batch_size=args.esm_batch_size,
        max_length=args.esm_max_length,
        fp16=(not args.no_fp16),
    )
    X, y, seqs = load_or_compute_embeddings(df, args.out_dir, embed_cfg, device=args.device)
    print(f"[Embeddings] X={X.shape} (float32)")

    # Optuna study
    sampler = optuna.samplers.TPESampler(seed=SEED)
    study = optuna.create_study(
        direction="maximize",  # maximize mean CV Spearman rho, as in the stability script
        sampler=sampler,
        pruner=optuna.pruners.MedianPruner(),
    )

    objective = make_cv_objective(
        X=X,
        y=y,
        n_splits=args.n_splits,
        device=args.device,
        base_model_json=args.base_model_json,
        target_transform=args.target_transform,
    )
    study.optimize(objective, n_trials=args.n_trials)

    # Save trials
    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(args.out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)

    # Build the full param dict for refit
    best_xgb_params = {
        "objective": "reg:squarederror",
        "eval_metric": "rmse",
        "lambda": best_params["lambda"],
        "alpha": best_params["alpha"],
        "gamma": best_params["gamma"],
        "max_depth": best_params["max_depth"],
        "min_child_weight": best_params["min_child_weight"],
        "subsample": best_params["subsample"],
        "colsample_bytree": best_params["colsample_bytree"],
        "learning_rate": best_params["learning_rate"],
        "tree_method": "hist",
        "device": "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu",
        "num_boost_round": best_params["num_boost_round"],
        "early_stopping_rounds": best_params["early_stopping_rounds"],
    }

    # Summary
    summary = {
        "best_trial_number": int(best.number),
        "best_value_cv_spearman_rho": float(best.value),
        "best_user_attrs": best.user_attrs,
        "best_params": best_params,
        "best_xgb_params_full": best_xgb_params,
        "base_model_json": args.base_model_json,
        "target_transform": args.target_transform,
        "esm_model": args.esm_model,
        "esm_max_length": args.esm_max_length,
    }
    with open(os.path.join(args.out_dir, "optimization_summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print("=" * 72)
    print("[Optuna] Best CV Spearman rho:", float(best.value))
    print("[Optuna] Best params:\n", json.dumps(best_params, indent=2))
    print("=" * 72)

    # Refit + save the final finetuned model + OOF predictions
    refit_and_save(
        X=X,
        y=y,
        seqs=seqs,
        out_dir=args.out_dir,
        best_params=best_xgb_params,
        n_splits=args.n_splits,
        device=args.device,
        base_model_json=args.base_model_json,
        target_transform=args.target_transform,
    )


if __name__ == "__main__":
    main()
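# Example invocation (a sketch; the default paths above are cluster-specific,
# so these are placeholders):
#   python finetune_boost.py \
#       --csv_path ./halflife.csv \
#       --out_dir ./finetune_xgb \
#       --base_model_json ./stability_xgb/best_model.json \
#       --target_transform log1p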
training_classifiers/.ipynb_checkpoints/generate_binding_val-checkpoint.py
ADDED
@@ -0,0 +1,309 @@
#!/usr/bin/env python3
# export_val_preds_csv.py

import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict

# -----------------------------
# Repro / device
# -----------------------------
def seed_all(seed=1986):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_all(1986)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# -----------------------------
# Load paired DatasetDict
# -----------------------------
def load_split_paired(path: str):
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    if "train" not in dd or "val" not in dd:
        raise ValueError(f"DatasetDict missing train/val at {path}")
    return dd["train"], dd["val"]


# -----------------------------
# Collate fns (same as training)
# -----------------------------
def collate_pair_pooled(batch):
    Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)
    Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    return Pt, Pb, y

def collate_pair_unpooled(batch):
    B = len(batch)
    Ht = len(batch[0]["target_embedding"][0])
    Hb = len(batch[0]["binder_embedding"][0])
    Lt_max = max(int(x["target_length"]) for x in batch)
    Lb_max = max(int(x["binder_length"]) for x in batch)

    Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
    Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
    Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
    Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        t = torch.tensor(x["target_embedding"], dtype=torch.float32)
        b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
        lt, lb = t.shape[0], b.shape[0]
        Pt[i, :lt] = t
        Pb[i, :lb] = b
        Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
        Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)

    return Pt, Mt, Pb, Mb, y


# -----------------------------
# Models (same as training)
# -----------------------------
class CrossAttnPooled(nn.Module):
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def forward(self, t_vec, b_vec):
        t = self.t_proj(t_vec).unsqueeze(0)  # (1,B,H)
        b = self.b_proj(b_vec).unsqueeze(0)  # (1,B,H)

        for L in self.layers:
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"]((t + t_attn).transpose(0, 1)).transpose(0, 1)
            t = L["n2t"]((t + L["fft"](t)).transpose(0, 1)).transpose(0, 1)

            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"]((b + b_attn).transpose(0, 1)).transpose(0, 1)
            b = L["n2b"]((b + L["ffb"](b)).transpose(0, 1)).transpose(0, 1)

        z = torch.cat([t[0], b[0]], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)


class CrossAttnUnpooled(nn.Module):
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

    def forward(self, T, Mt, B, Mb):
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        kp_t = ~Mt
        kp_b = ~Mb

        for L in self.layers:
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))

            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))

        t_pool = self.masked_mean(T, Mt)
        b_pool = self.masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)


# -----------------------------
# Helpers
# -----------------------------
def softmax_np(logits: np.ndarray) -> np.ndarray:
    x = logits - logits.max(axis=1, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=1, keepdims=True)

def expected_score_from_probs(probs: np.ndarray, class_centers=(9.5, 8.0, 6.0)) -> np.ndarray:
    centers = np.asarray(class_centers, dtype=np.float32)[None, :]  # (1,3)
    return (probs * centers).sum(axis=1)

def load_checkpoint(ckpt_path: str, mode: str, train_ds):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    params = ckpt.get("best_params", {})

    hidden = int(params.get("hidden_dim", 512))
    n_heads = int(params.get("n_heads", 8))
    n_layers = int(params.get("n_layers", 3))
    dropout = float(params.get("dropout", 0.1))

    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)

    model.load_state_dict(ckpt["state_dict"], strict=True)
    model.to(DEVICE).eval()
    return model


@torch.no_grad()
def export_val_preds_csv(dataset_path: str, ckpt_path: str, mode: str,
                         out_csv: str, batch_size: int, num_workers: int,
                         class_centers=(9.5, 8.0, 6.0)):
    train_ds, val_ds = load_split_paired(dataset_path)
    model = load_checkpoint(ckpt_path, mode, train_ds)

    if mode == "pooled":
        loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True,
                            collate_fn=collate_pair_pooled)
        y_all, pred_reg_all, logits_all = [], [], []
        for t, b, y in loader:
            t = t.to(DEVICE, non_blocking=True)
            b = b.to(DEVICE, non_blocking=True)
            pred_reg, logits = model(t, b)
            y_all.append(y.numpy())
            pred_reg_all.append(pred_reg.detach().cpu().numpy())
            logits_all.append(logits.detach().cpu().numpy())

    else:
        loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True,
                            collate_fn=collate_pair_unpooled)
        y_all, pred_reg_all, logits_all = [], [], []
        for T, Mt, B, Mb, y in loader:
            T = T.to(DEVICE, non_blocking=True)
            Mt = Mt.to(DEVICE, non_blocking=True)
            B = B.to(DEVICE, non_blocking=True)
            Mb = Mb.to(DEVICE, non_blocking=True)
            pred_reg, logits = model(T, Mt, B, Mb)
            y_all.append(y.numpy())
            pred_reg_all.append(pred_reg.detach().cpu().numpy())
            logits_all.append(logits.detach().cpu().numpy())

    y_true = np.concatenate(y_all)
    y_pred_reg = np.concatenate(pred_reg_all)
    logits = np.concatenate(logits_all)

    probs = softmax_np(logits)  # (N,3)
    y_pred_cls_score = expected_score_from_probs(probs, class_centers=class_centers)

    # Build CSV rows
    out = Path(out_csv)
    out.parent.mkdir(parents=True, exist_ok=True)

    header = [
        "split", "mode",
        "y_true",
        "y_pred_reg",
        "p_high", "p_moderate", "p_low",
        "y_pred_cls_score",
        "center_high", "center_moderate", "center_low",
    ]

    centers = list(class_centers)
    rows = np.column_stack([
        y_true,
        y_pred_reg,
        probs[:, 0], probs[:, 1], probs[:, 2],
        y_pred_cls_score,
        np.full_like(y_true, centers[0], dtype=np.float32),
        np.full_like(y_true, centers[1], dtype=np.float32),
        np.full_like(y_true, centers[2], dtype=np.float32),
    ])

    with out.open("w") as f:
        f.write(",".join(header) + "\n")
        for i in range(rows.shape[0]):
            f.write(
                "val," + mode + "," +
                ",".join(f"{rows[i, j]:.8f}" for j in range(rows.shape[1])) +
                "\n"
            )

    print(f"[Data] Val N={len(y_true)} | mode={mode}")
    print(f"[Saved] {out}")


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", required=True, help="Paired DatasetDict path (pair_*)")
    ap.add_argument("--ckpt", required=True, help="Path to best_model.pt")
    ap.add_argument("--mode", choices=["pooled", "unpooled"], required=True)
    ap.add_argument("--out_csv", required=True)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=4)

    # Optional: choose class-centers for expected-score conversion
    ap.add_argument("--center_high", type=float, default=9.5)
    ap.add_argument("--center_moderate", type=float, default=8.0)
    ap.add_argument("--center_low", type=float, default=6.0)

    args = ap.parse_args()

    export_val_preds_csv(
        dataset_path=args.dataset_path,
        ckpt_path=args.ckpt,
        mode=args.mode,
        out_csv=args.out_csv,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        class_centers=(args.center_high, args.center_moderate, args.center_low),
    )


if __name__ == "__main__":
    main()
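Note: the expected-score conversion in expected_score_from_probs above is just a probability-weighted average of the class centers. A minimal standalone sketch of the same arithmetic, using the script's default centers (9.5, 8.0, 6.0) and made-up class probabilities (illustration only, not part of the commit):

import numpy as np

# Class probabilities for (high, moderate, low); each row sums to 1.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]], dtype=np.float32)
centers = np.array([9.5, 8.0, 6.0], dtype=np.float32)

# Same computation as expected_score_from_probs: (N,3) @ (3,) -> (N,)
scores = probs @ centers
print(scores)  # approx [8.85 6.95]

A confident "high" prediction is thus pulled toward 9.5, while a mostly "low" prediction lands near 6.0; this is what the y_pred_cls_score column in the exported CSV holds.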
training_classifiers/.ipynb_checkpoints/peptiverse_filelist-checkpoint.txt
ADDED
@@ -0,0 +1,234 @@
./hemolysis/cnn_smiles/optimization_summary.txt
./hemolysis/cnn_smiles/pr_curve.png
./hemolysis/cnn_smiles/roc_curve.png
./hemolysis/cnn_smiles/study_trials.csv
./hemolysis/cnn_smiles/train_predictions.csv
./hemolysis/cnn_smiles/val_predictions.csv
./hemolysis/cnn_wt/optimization_summary.txt
./hemolysis/cnn_wt/pr_curve.png
./hemolysis/cnn_wt/roc_curve.png
./hemolysis/cnn_wt/study_trials.csv
./hemolysis/cnn_wt/train_predictions.csv
./hemolysis/cnn_wt/val_predictions.csv
./hemolysis/enet_gpu/optimization_summary.txt
./hemolysis/enet_gpu/pr_curve.png
./hemolysis/enet_gpu/roc_curve.png
./hemolysis/enet_gpu/study_trials.csv
./hemolysis/enet_gpu/train_predictions.csv
./hemolysis/enet_gpu/val_predictions.csv
./hemolysis/enet_gpu_smiles/optimization_summary.txt
./hemolysis/enet_gpu_smiles/pr_curve.png
./hemolysis/enet_gpu_smiles/roc_curve.png
./hemolysis/enet_gpu_smiles/study_trials.csv
./hemolysis/enet_gpu_smiles/train_predictions.csv
./hemolysis/enet_gpu_smiles/val_predictions.csv
./hemolysis/enet_gpu_wt/optimization_summary.txt
./hemolysis/enet_gpu_wt/pr_curve.png
./hemolysis/enet_gpu_wt/roc_curve.png
./hemolysis/enet_gpu_wt/study_trials.csv
./hemolysis/enet_gpu_wt/train_predictions.csv
./hemolysis/enet_gpu_wt/val_predictions.csv
./hemolysis/mlp_smiles/optimization_summary.txt
./hemolysis/mlp_smiles/pr_curve.png
./hemolysis/mlp_smiles/roc_curve.png
./hemolysis/mlp_smiles/study_trials.csv
./hemolysis/mlp_smiles/train_predictions.csv
./hemolysis/mlp_smiles/val_predictions.csv
./hemolysis/mlp_wt/optimization_summary.txt
./hemolysis/mlp_wt/pr_curve.png
./hemolysis/mlp_wt/roc_curve.png
./hemolysis/mlp_wt/study_trials.csv
./hemolysis/mlp_wt/train_predictions.csv
./hemolysis/mlp_wt/val_predictions.csv
./hemolysis/svm_gpu_wt/optimization_summary.txt
./hemolysis/svm_gpu_wt/pr_curve.png
./hemolysis/svm_gpu_wt/roc_curve.png
./hemolysis/svm_gpu_wt/study_trials.csv
./hemolysis/svm_gpu_wt/train_predictions.csv
./hemolysis/svm_gpu_wt/val_predictions.csv
./hemolysis/transformer_smiles/optimization_summary.txt
./hemolysis/transformer_smiles/pr_curve.png
./hemolysis/transformer_smiles/roc_curve.png
./hemolysis/transformer_smiles/study_trials.csv
./hemolysis/transformer_smiles/train_predictions.csv
./hemolysis/transformer_smiles/val_predictions.csv
./hemolysis/transformer_wt/optimization_summary.txt
./hemolysis/transformer_wt/pr_curve.png
./hemolysis/transformer_wt/roc_curve.png
./hemolysis/transformer_wt/study_trials.csv
./hemolysis/transformer_wt/train_predictions.csv
./hemolysis/transformer_wt/val_predictions.csv
./hemolysis/xgb/optimization_summary.txt
./hemolysis/xgb/pr_curve.png
./hemolysis/xgb/roc_curve.png
./hemolysis/xgb/study_trials.csv
./hemolysis/xgb/train_predictions.csv
./hemolysis/xgb/val_predictions.csv
./hemolysis/xgb_smiles/optimization_summary.txt
./hemolysis/xgb_smiles/pr_curve.png
./hemolysis/xgb_smiles/roc_curve.png
./hemolysis/xgb_smiles/study_trials.csv
./hemolysis/xgb_smiles/train_predictions.csv
./hemolysis/xgb_smiles/val_predictions.csv
./hemolysis/xgb_wt/optimization_summary.txt
./hemolysis/xgb_wt/pr_curve.png
./hemolysis/xgb_wt/roc_curve.png
./hemolysis/xgb_wt/study_trials.csv
./hemolysis/xgb_wt/train_predictions.csv
./hemolysis/xgb_wt/val_predictions.csv
./nf/cnn/optimization_summary.txt
./nf/cnn/pr_curve.png
./nf/cnn/roc_curve.png
./nf/cnn/study_trials.csv
./nf/cnn/train_predictions.csv
./nf/cnn/val_predictions.csv
./nf/cnn_wt/optimization_summary.txt
./nf/cnn_wt/pr_curve.png
./nf/cnn_wt/roc_curve.png
./nf/cnn_wt/study_trials.csv
./nf/cnn_wt/train_predictions.csv
./nf/cnn_wt/val_predictions.csv
./nf/enet_gpu/optimization_summary.txt
./nf/enet_gpu/pr_curve.png
./nf/enet_gpu/roc_curve.png
./nf/enet_gpu/study_trials.csv
./nf/enet_gpu/train_predictions.csv
./nf/enet_gpu/val_predictions.csv
./nf/enet_gpu_smiles/optimization_summary.txt
./nf/enet_gpu_smiles/pr_curve.png
./nf/enet_gpu_smiles/roc_curve.png
./nf/enet_gpu_smiles/study_trials.csv
./nf/enet_gpu_smiles/train_predictions.csv
./nf/enet_gpu_smiles/val_predictions.csv
./nf/enet_gpu_wt/optimization_summary.txt
./nf/enet_gpu_wt/pr_curve.png
./nf/enet_gpu_wt/roc_curve.png
./nf/enet_gpu_wt/study_trials.csv
./nf/enet_gpu_wt/train_predictions.csv
./nf/enet_gpu_wt/val_predictions.csv
./nf/mlp/optimization_summary.txt
./nf/mlp/pr_curve.png
./nf/mlp/roc_curve.png
./nf/mlp/study_trials.csv
./nf/mlp/train_predictions.csv
./nf/mlp/val_predictions.csv
./nf/mlp_wt/optimization_summary.txt
./nf/mlp_wt/pr_curve.png
./nf/mlp_wt/roc_curve.png
./nf/mlp_wt/study_trials.csv
./nf/mlp_wt/train_predictions.csv
./nf/mlp_wt/val_predictions.csv
./nf/svm_gpu/optimization_summary.txt
./nf/svm_gpu/pr_curve.png
./nf/svm_gpu/roc_curve.png
./nf/svm_gpu/study_trials.csv
./nf/svm_gpu/train_predictions.csv
./nf/svm_gpu/val_predictions.csv
./nf/svm_gpu_wt/optimization_summary.txt
./nf/svm_gpu_wt/pr_curve.png
./nf/svm_gpu_wt/roc_curve.png
./nf/svm_gpu_wt/study_trials.csv
./nf/svm_gpu_wt/train_predictions.csv
./nf/svm_gpu_wt/val_predictions.csv
./nf/transformer/optimization_summary.txt
./nf/transformer/pr_curve.png
./nf/transformer/roc_curve.png
./nf/transformer/study_trials.csv
./nf/transformer/train_predictions.csv
./nf/transformer/val_predictions.csv
./nf/transformer_wt/optimization_summary.txt
./nf/transformer_wt/pr_curve.png
./nf/transformer_wt/roc_curve.png
./nf/transformer_wt/study_trials.csv
./nf/transformer_wt/train_predictions.csv
./nf/transformer_wt/val_predictions.csv
./nf/xgb_wt/optimization_summary.txt
./nf/xgb_wt/pr_curve.png
./nf/xgb_wt/roc_curve.png
./nf/xgb_wt/study_trials.csv
./nf/xgb_wt/train_predictions.csv
./nf/xgb_wt/val_predictions.csv
./permeability_caco2/cnn_smiles/optimization_summary.txt
./permeability_caco2/cnn_smiles/study_trials.csv
./permeability_caco2/cnn_smiles/train_predictions.csv
./permeability_caco2/cnn_smiles/val_predictions.csv
./permeability_caco2/enet_gpu_smiles/optimization_summary.txt
./permeability_caco2/enet_gpu_smiles/study_trials.csv
./permeability_caco2/enet_gpu_smiles/train_predictions.csv
./permeability_caco2/enet_gpu_smiles/val_predictions.csv
./permeability_caco2/mlp_smiles/optimization_summary.txt
./permeability_caco2/mlp_smiles/study_trials.csv
./permeability_caco2/mlp_smiles/train_predictions.csv
./permeability_caco2/mlp_smiles/val_predictions.csv
./permeability_caco2/svr_smiles/optimization_summary.txt
./permeability_caco2/svr_smiles/study_trials.csv
./permeability_caco2/svr_smiles/train_predictions.csv
./permeability_caco2/svr_smiles/val_predictions.csv
./permeability_caco2/transformer_smiles/optimization_summary.txt
./permeability_caco2/transformer_smiles/study_trials.csv
./permeability_caco2/transformer_smiles/train_predictions.csv
./permeability_caco2/transformer_smiles/val_predictions.csv
./permeability_caco2/xgb_reg_smiles/optimization_summary.txt
./permeability_caco2/xgb_reg_smiles/study_trials.csv
./permeability_caco2/xgb_reg_smiles/train_predictions.csv
./permeability_caco2/xgb_reg_smiles/val_predictions.csv
./permeability_pampa/cnn_smiles/optimization_summary.txt
./permeability_pampa/cnn_smiles/study_trials.csv
./permeability_pampa/cnn_smiles/train_predictions.csv
./permeability_pampa/cnn_smiles/val_predictions.csv
./permeability_pampa/enet_gpu_smiles/optimization_summary.txt
./permeability_pampa/enet_gpu_smiles/study_trials.csv
./permeability_pampa/enet_gpu_smiles/train_predictions.csv
./permeability_pampa/enet_gpu_smiles/val_predictions.csv
./permeability_pampa/mlp_smiles/optimization_summary.txt
./permeability_pampa/mlp_smiles/study_trials.csv
./permeability_pampa/mlp_smiles/train_predictions.csv
./permeability_pampa/mlp_smiles/val_predictions.csv
./permeability_pampa/transformer_smiles/optimization_summary.txt
./permeability_pampa/transformer_smiles/study_trials.csv
./permeability_pampa/transformer_smiles/train_predictions.csv
./permeability_pampa/transformer_smiles/val_predictions.csv
./permeability_pampa/xgb_reg_smiles/optimization_summary.txt
./permeability_pampa/xgb_reg_smiles/study_trials.csv
./permeability_pampa/xgb_reg_smiles/train_predictions.csv
./permeability_pampa/xgb_reg_smiles/val_predictions.csv
./solubility/cnn_wt/optimization_summary.txt
./solubility/cnn_wt/pr_curve.png
./solubility/cnn_wt/roc_curve.png
./solubility/cnn_wt/study_trials.csv
./solubility/cnn_wt/train_predictions.csv
./solubility/cnn_wt/val_predictions.csv
./solubility/enet_gpu/optimization_summary.txt
./solubility/enet_gpu/pr_curve.png
./solubility/enet_gpu/roc_curve.png
./solubility/enet_gpu/study_trials.csv
./solubility/enet_gpu/train_predictions.csv
./solubility/enet_gpu/val_predictions.csv
./solubility/mlp_wt/optimization_summary.txt
./solubility/mlp_wt/pr_curve.png
./solubility/mlp_wt/roc_curve.png
./solubility/mlp_wt/study_trials.csv
./solubility/mlp_wt/train_predictions.csv
./solubility/mlp_wt/val_predictions.csv
./solubility/svm_gpu/optimization_summary.txt
./solubility/svm_gpu/pr_curve.png
./solubility/svm_gpu/roc_curve.png
./solubility/svm_gpu/study_trials.csv
./solubility/svm_gpu/train_predictions.csv
./solubility/svm_gpu/val_predictions.csv
./solubility/transformer_wt/optimization_summary.txt
./solubility/transformer_wt/pr_curve.png
./solubility/transformer_wt/roc_curve.png
./solubility/transformer_wt/study_trials.csv
./solubility/transformer_wt/train_predictions.csv
./solubility/transformer_wt/val_predictions.csv
./solubility/xgb/optimization_summary.txt
./solubility/xgb/pr_curve.png
./solubility/xgb/roc_curve.png
./solubility/xgb/study_trials.csv
./solubility/xgb/train_predictions.csv
./solubility/xgb/val_predictions.csv
./binding_affinity/wt_wt_pooled/optuna_trials.csv
./binding_affinity/wt_smiles_pooled/optuna_trials.csv
./binding_affinity/wt_smiles_unpooled/optuna_trials.csv
./binding_affinity/wt_wt_unpooled/optuna_trials.csv
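Note: a filelist like the one above can be regenerated with a short script. A minimal sketch, assuming it runs from the training_classifiers output root and that only .txt/.png/.csv artifacts are wanted (both are assumptions, and the resulting sort order may differ from the committed list):

from pathlib import Path

# Hypothetical regeneration of the filelist; root and extension
# filter are assumptions, not part of the commit.
root = Path(".")
exts = {".txt", ".png", ".csv"}

paths = sorted(
    f"./{p.relative_to(root)}"
    for p in root.rglob("*")
    if p.is_file() and p.suffix in exts
)
print("\n".join(paths))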
training_classifiers/.ipynb_checkpoints/train_boost-checkpoint.py
ADDED
@@ -0,0 +1,417 @@
import os
import json
import joblib
import optuna
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from dataclasses import dataclass
from typing import Dict, Any, Tuple, Optional

from datasets import load_from_disk, DatasetDict
from sklearn.metrics import (
    f1_score, roc_auc_score, average_precision_score,
    precision_recall_curve, roc_curve
)
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from linearboost import LinearBoostClassifier

import xgboost as xgb
from lightning.pytorch import seed_everything

seed_everything(1986)

# -----------------------------
# Data loading
# -----------------------------
@dataclass
class SplitData:
    X_train: np.ndarray
    y_train: np.ndarray
    seq_train: Optional[np.ndarray]
    X_val: np.ndarray
    y_val: np.ndarray
    seq_val: Optional[np.ndarray]


def _stack_embeddings(col) -> np.ndarray:
    # HF datasets often store embeddings as list-of-floats per row
    arr = np.asarray(col, dtype=np.float32)
    if arr.ndim != 2:
        arr = np.stack(col).astype(np.float32)
    return arr


def load_split_data(dataset_path: str) -> SplitData:
    ds = load_from_disk(dataset_path)

    # Case A: DatasetDict with train/val
    if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
        train_ds, val_ds = ds["train"], ds["val"]
    else:
        # Case B: Single dataset with "split" column
        if "split" not in ds.column_names:
            raise ValueError(
                "Dataset must be a DatasetDict(train/val) or have a 'split' column."
            )
        train_ds = ds.filter(lambda x: x["split"] == "train")
        val_ds = ds.filter(lambda x: x["split"] == "val")

    for required in ["embedding", "label"]:
        if required not in train_ds.column_names:
            raise ValueError(f"Missing column '{required}' in train split.")
        if required not in val_ds.column_names:
            raise ValueError(f"Missing column '{required}' in val split.")

    X_train = _stack_embeddings(train_ds["embedding"])
    y_train = np.asarray(train_ds["label"], dtype=np.int64)

    X_val = _stack_embeddings(val_ds["embedding"])
    y_val = np.asarray(val_ds["label"], dtype=np.int64)

    seq_train = None
    seq_val = None
    if "sequence" in train_ds.column_names:
        seq_train = np.asarray(train_ds["sequence"])
    if "sequence" in val_ds.column_names:
        seq_val = np.asarray(val_ds["sequence"])

    return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)


# -----------------------------
# Metrics + thresholding
# -----------------------------
def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
    """
    Find the threshold maximizing F1 on the given set.
    Returns (best_threshold, best_f1).
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    # precision_recall_curve returns thresholds of length n-1;
    # compute F1 for those thresholds
    f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
    best_idx = int(np.nanargmax(f1s))
    return float(thresholds[best_idx]), float(f1s[best_idx])


def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
    y_pred = (y_prob >= threshold).astype(int)
    return {
        "f1": float(f1_score(y_true, y_pred)),
        "auc": float(roc_auc_score(y_true, y_prob)),
        "ap": float(average_precision_score(y_true, y_prob)),
        "threshold": float(threshold),
    }


# -----------------------------
# Model factories
# -----------------------------
def train_xgb(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    num_boost_round = int(params.pop("num_boost_round"))
    early_stopping_rounds = int(params.pop("early_stopping_rounds"))

    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dval, "val")],
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=False,
    )

    p_train = booster.predict(dtrain)
    p_val = booster.predict(dval)
    return booster, p_train, p_val


def train_adaboost(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
) -> Tuple[AdaBoostClassifier, np.ndarray, np.ndarray]:
    base_depth = int(params.pop("base_depth"))
    clf = AdaBoostClassifier(
        estimator=DecisionTreeClassifier(max_depth=base_depth),
        n_estimators=int(params["n_estimators"]),
        learning_rate=float(params["learning_rate"]),
        algorithm="SAMME",
    )
    clf.fit(X_train, y_train)
    p_train = clf.predict_proba(X_train)[:, 1]
    p_val = clf.predict_proba(X_val)[:, 1]
    return clf, p_train, p_val


def train_linearboost(X_train, y_train, X_val, y_val, params):
    clf = LinearBoostClassifier(**params)
    clf.fit(X_train, y_train)
    p_train = clf.predict_proba(X_train)[:, 1]
    p_val = clf.predict_proba(X_val)[:, 1]
    return clf, p_train, p_val


def suggest_linearboost_params(trial):
    # Core boosting params
    params = {
        "n_estimators": trial.suggest_int("n_estimators", 50, 800),
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 1.0, log=True),
        "algorithm": trial.suggest_categorical("algorithm", ["SAMME.R", "SAMME"]),
        # Scaling choices from the LinearBoost docs (expand this list if needed)
        "scaler": trial.suggest_categorical(
            "scaler",
            ["minmax", "standard", "robust", "quantile-uniform", "quantile-normal", "power"]
        ),
        # useful for imbalanced splits
        "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
        # kernel trick
        "kernel": trial.suggest_categorical("kernel", ["linear", "rbf", "poly", "sigmoid"]),
    }

    # Kernel-specific params (only when relevant)
    if params["kernel"] in ["rbf", "poly"]:
        params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
    else:
        params["gamma"] = None  # docs: default treated as 1/n_features for rbf/poly

    if params["kernel"] == "poly":
        params["degree"] = trial.suggest_int("degree", 2, 6)      # docs default=3
        params["coef0"] = trial.suggest_float("coef0", 0.0, 5.0)  # docs default=1
    else:
        # safe defaults
        params["degree"] = 3
        params["coef0"] = 1.0

    return params

# -----------------------------
# Saving artifacts
# -----------------------------
def save_predictions_csv(
    out_dir: str,
    split_name: str,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
    sequences: Optional[np.ndarray] = None,
):
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame({
        "y_true": y_true.astype(int),
        "y_prob": y_prob.astype(float),
        "y_pred": (y_prob >= threshold).astype(int),
    })
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)


def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
    os.makedirs(out_dir, exist_ok=True)

    # PR
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    plt.figure()
    plt.plot(recall, precision)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pr_curve.png"))
    plt.close()

    # ROC
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "roc_curve.png"))
    plt.close()


# -----------------------------
# Optuna objectives
# -----------------------------
def make_objective(model_name: str, data: SplitData, out_dir: str):
    Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val

    def objective(trial: optuna.Trial) -> float:
        if model_name == "xgb":
            params = {
                "objective": "binary:logistic",
                "eval_metric": "logloss",
                "lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
                "alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
                "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
                "subsample": trial.suggest_float("subsample", 0.5, 1.0),
                "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
                "max_depth": trial.suggest_int("max_depth", 2, 15),
                "min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
                "gamma": trial.suggest_float("gamma", 0.0, 10.0),
                "tree_method": "hist",
                "device": "cuda",  # requires an XGBoost build with GPU support
            }

            params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
            params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)

            model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())

        elif model_name == "adaboost":
            params = {
                "n_estimators": trial.suggest_int("n_estimators", 50, 800),
                "learning_rate": trial.suggest_float("learning_rate", 1e-3, 2.0, log=True),
                "base_depth": trial.suggest_int("base_depth", 1, 4),
            }
            model, p_tr, p_va = train_adaboost(Xtr, ytr, Xva, yva, params)

        elif model_name == "linearboost":
            params = suggest_linearboost_params(trial)
            model, p_tr, p_va = train_linearboost(Xtr, ytr, Xva, yva, params)
        else:
            raise ValueError(f"Unknown model_name={model_name}")

        # Threshold picked on val for fair comparison across models
        thr, f1_at_thr = best_f1_threshold(yva, p_va)
        metrics = eval_binary(yva, p_va, thr)

        # Track best trial artifacts inside the study directory
        trial.set_user_attr("threshold", thr)
        trial.set_user_attr("auc", metrics["auc"])
        trial.set_user_attr("ap", metrics["ap"])

        return f1_at_thr

    return objective

# -----------------------------
# Main runner
# -----------------------------
def run_optuna_and_refit(
    dataset_path: str,
    out_dir: str,
    model_name: str,
    n_trials: int = 200,
):
    os.makedirs(out_dir, exist_ok=True)

    data = load_split_data(dataset_path)
    print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)

    # Save trials table
    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)
    best_thr = float(best.user_attrs["threshold"])
    best_auc = float(best.user_attrs["auc"])
    best_ap = float(best.user_attrs["ap"])
    best_f1 = float(best.value)

    # Refit best model on train (same protocol as objective)
    if model_name == "xgb":
        # Reconstruct full param dict
        params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "lambda": best_params["lambda"],
            "alpha": best_params["alpha"],
            "colsample_bytree": best_params["colsample_bytree"],
            "subsample": best_params["subsample"],
            "learning_rate": best_params["learning_rate"],
            "max_depth": best_params["max_depth"],
            "min_child_weight": best_params["min_child_weight"],
            "gamma": best_params["gamma"],
            "tree_method": "hist",
            "num_boost_round": best_params["num_boost_round"],
            "early_stopping_rounds": best_params["early_stopping_rounds"],
        }
        model, p_tr, p_va = train_xgb(
            data.X_train, data.y_train, data.X_val, data.y_val, params
        )
        model_path = os.path.join(out_dir, "best_model.json")
        model.save_model(model_path)

    elif model_name == "adaboost":
        params = best_params
        model, p_tr, p_va = train_adaboost(
            data.X_train, data.y_train, data.X_val, data.y_val, params
        )
        model_path = os.path.join(out_dir, "best_model.joblib")
        joblib.dump(model, model_path)

    elif model_name == "linearboost":
        params = best_params

        model, p_tr, p_va = train_linearboost(
            data.X_train, data.y_train, data.X_val, data.y_val, params
        )

        model_path = os.path.join(out_dir, "best_model.joblib")
        joblib.dump(model, model_path)
    else:
        raise ValueError(model_name)

    # Save predictions CSVs
    save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
    save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)

    # Plots on val
    plot_curves(out_dir, data.y_val, p_va)

    # Summary
    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        f"Best trial: {best.number}",
        f"Best F1 (val @ best-threshold): {best_f1:.4f}",
        f"Val AUC: {best_auc:.4f}",
        f"Val AP: {best_ap:.4f}",
        f"Best threshold (picked on val): {best_thr:.4f}",
        f"Model saved to: {model_path}",
        "Best params:",
        json.dumps(best_params, indent=2),
        "=" * 72,
    ]
    with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))


if __name__ == "__main__":
    # Example usage:
    # dataset_path = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/data/solubility"
    # out_dir = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/src/solubility/xgb"
    # run_optuna_and_refit(dataset_path, out_dir, model_name="xgb", n_trials=200)

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["xgb", "adaboost", "linearboost"], required=True)
    parser.add_argument("--n_trials", type=int, default=200)
    args = parser.parse_args()

    run_optuna_and_refit(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        model_name=args.model,
        n_trials=args.n_trials,
    )
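Note: every model in train_boost is compared at the F1-maximizing validation threshold chosen by best_f1_threshold. A tiny self-contained check of that selection logic on toy labels and scores (the values are illustrative only, not from any dataset in this repo):

import numpy as np
from sklearn.metrics import precision_recall_curve, f1_score

y_true = np.array([0, 0, 1, 1, 1])
y_prob = np.array([0.1, 0.6, 0.4, 0.8, 0.9])

# Same F1-over-thresholds scan as best_f1_threshold
precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
best = int(np.nanargmax(f1s))
thr = thresholds[best]
print(thr, f1_score(y_true, (y_prob >= thr).astype(int)))  # ~0.4, F1 ~0.857

Here the scan prefers 0.4 over the default 0.5: lowering the cutoff recovers the positive scored at 0.4 at the cost of one false positive, which maximizes F1 on this toy set.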
training_classifiers/.ipynb_checkpoints/train_ml-checkpoint.py
ADDED
@@ -0,0 +1,468 @@
import os
import json
import joblib
import optuna
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, Any, Tuple, Optional
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import (
    f1_score, roc_auc_score, average_precision_score,
    precision_recall_curve, roc_curve
)
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC, LinearSVC
from sklearn.calibration import CalibratedClassifierCV
import torch
import time
import xgboost as xgb
from lightning.pytorch import seed_everything
import cupy as cp
from cuml.svm import SVC as cuSVC
from cuml.linear_model import LogisticRegression as cuLogReg

seed_everything(1986)


def to_gpu(X: np.ndarray):
    if isinstance(X, cp.ndarray):
        return X
    return cp.asarray(X, dtype=cp.float32)

def to_cpu(x):
    if isinstance(x, cp.ndarray):
        return cp.asnumpy(x)
    return np.asarray(x)

@dataclass
class SplitData:
    X_train: np.ndarray
    y_train: np.ndarray
    seq_train: Optional[np.ndarray]
    X_val: np.ndarray
    y_val: np.ndarray
    seq_val: Optional[np.ndarray]


def _stack_embeddings(col) -> np.ndarray:
    arr = np.asarray(col, dtype=np.float32)
    if arr.ndim != 2:
        arr = np.stack(col).astype(np.float32)
    return arr


def load_split_data(dataset_path: str) -> SplitData:
    ds = load_from_disk(dataset_path)

    # Case A: DatasetDict with train/val
    if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
        train_ds, val_ds = ds["train"], ds["val"]
    else:
        # Case B: Single dataset with "split" column
        if "split" not in ds.column_names:
            raise ValueError(
                "Dataset must be a DatasetDict(train/val) or have a 'split' column."
            )
        train_ds = ds.filter(lambda x: x["split"] == "train")
        val_ds = ds.filter(lambda x: x["split"] == "val")

    for required in ["embedding", "label"]:
        if required not in train_ds.column_names:
            raise ValueError(f"Missing column '{required}' in train split.")
        if required not in val_ds.column_names:
            raise ValueError(f"Missing column '{required}' in val split.")

    X_train = _stack_embeddings(train_ds["embedding"])
    y_train = np.asarray(train_ds["label"], dtype=np.int64)

    X_val = _stack_embeddings(val_ds["embedding"])
    y_val = np.asarray(val_ds["label"], dtype=np.int64)

    seq_train = None
    seq_val = None
    if "sequence" in train_ds.column_names:
        seq_train = np.asarray(train_ds["sequence"])
    if "sequence" in val_ds.column_names:
        seq_val = np.asarray(val_ds["sequence"])

    return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)


def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
    """
    Find the threshold maximizing F1 on the given set.
    Returns (best_threshold, best_f1).
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
    best_idx = int(np.nanargmax(f1s))
    return float(thresholds[best_idx]), float(f1s[best_idx])


def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
    y_pred = (y_prob >= threshold).astype(int)
    return {
        "f1": float(f1_score(y_true, y_pred)),
        "auc": float(roc_auc_score(y_true, y_prob)),
        "ap": float(average_precision_score(y_true, y_prob)),
        "threshold": float(threshold),
    }


# -----------------------------
# Model
# -----------------------------
def train_xgb(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    num_boost_round = int(params.pop("num_boost_round"))
    early_stopping_rounds = int(params.pop("early_stopping_rounds"))

    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dval, "val")],
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=False,
    )

    p_train = booster.predict(dtrain)
    p_val = booster.predict(dval)
    return booster, p_train, p_val

def train_cuml_svc(X_train, y_train, X_val, y_val, params):
    Xtr = to_gpu(X_train)
    Xva = to_gpu(X_val)
    ytr = to_gpu(y_train).astype(cp.int32)

    clf = cuSVC(
        C=float(params["C"]),
        kernel=params["kernel"],
        gamma=params.get("gamma", "scale"),
        class_weight=params.get("class_weight", None),
        probability=bool(params.get("probability", True)),
        random_state=1986,
        max_iter=int(params.get("max_iter", 1000)),
        tol=float(params.get("tol", 1e-4)),
    )

    clf.fit(Xtr, ytr)

    p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
    p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
    return clf, p_train, p_val

def train_cuml_elastic_net(X_train, y_train, X_val, y_val, params):
    Xtr = to_gpu(X_train)
    Xva = to_gpu(X_val)
    ytr = to_gpu(y_train).astype(cp.int32)

    clf = cuLogReg(
        penalty="elasticnet",
        C=float(params["C"]),
        l1_ratio=float(params["l1_ratio"]),
        class_weight=params.get("class_weight", None),
        max_iter=int(params.get("max_iter", 1000)),
        tol=float(params.get("tol", 1e-4)),
        solver="qn",
        fit_intercept=True,
    )
    clf.fit(Xtr, ytr)

    p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
    p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
    return clf, p_train, p_val


def train_svm(X_train, y_train, X_val, y_val, params):
    """
    Kernel SVM via SVC. CPU only in sklearn.
    probability=True enables predict_proba but is slower.
    """
    clf = SVC(
        C=float(params["C"]),
        kernel=params["kernel"],
        gamma=params.get("gamma", "scale"),
        class_weight=params.get("class_weight", None),
        probability=True,
        random_state=1986,
    )
    clf.fit(X_train, y_train)
    p_train = clf.predict_proba(X_train)[:, 1]
    p_val = clf.predict_proba(X_val)[:, 1]
    return clf, p_train, p_val


def train_linearsvm_calibrated(X_train, y_train, X_val, y_val, params):
    """
    Fast linear SVM (LinearSVC) + probability calibration.
    Usually much faster than SVC on large datasets.
    """
    base = LinearSVC(
        C=float(params["C"]),
        class_weight=params.get("class_weight", None),
        max_iter=int(params.get("max_iter", 5000)),
        random_state=1986,
    )
    # calibration to get probabilities for PR/ROC + thresholding
    clf = CalibratedClassifierCV(base, method="sigmoid", cv=3)
    clf.fit(X_train, y_train)
    p_train = clf.predict_proba(X_train)[:, 1]
    p_val = clf.predict_proba(X_val)[:, 1]
    return clf, p_train, p_val

# -----------------------------
# Saving artifacts
# -----------------------------
def save_predictions_csv(
    out_dir: str,
    split_name: str,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
    sequences: Optional[np.ndarray] = None,
):
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame({
        "y_true": y_true.astype(int),
        "y_prob": y_prob.astype(float),
        "y_pred": (y_prob >= threshold).astype(int),
    })
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)


def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
    os.makedirs(out_dir, exist_ok=True)

    # PR
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    plt.figure()
    plt.plot(recall, precision)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pr_curve.png"))
    plt.close()

    # ROC
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "roc_curve.png"))
    plt.close()


# -----------------------------
# Optuna objectives
# -----------------------------
def make_objective(model_name: str, data: SplitData, out_dir: str):
    Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val

    def objective(trial: optuna.Trial) -> float:
        if model_name == "xgb":
            params = {
                "objective": "binary:logistic",
                "eval_metric": "logloss",
                "lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
                "alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
                "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
                "subsample": trial.suggest_float("subsample", 0.5, 1.0),
                "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
                "max_depth": trial.suggest_int("max_depth", 2, 15),
                "min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
                "gamma": trial.suggest_float("gamma", 0.0, 10.0),
"tree_method": "hist",
|
| 287 |
+
"device": "cuda",
|
| 288 |
+
}
|
| 289 |
+
params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
|
| 290 |
+
params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
|
| 291 |
+
|
| 292 |
+
model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())
|
| 293 |
+
|
| 294 |
+
elif model_name == "svm":
|
| 295 |
+
svm_kind = trial.suggest_categorical("svm_kind", ["svc", "linear_calibrated"])
|
| 296 |
+
|
| 297 |
+
if svm_kind == "svc":
|
| 298 |
+
params = {
|
| 299 |
+
"C": trial.suggest_float("C", 1e-3, 1e3, log=True),
|
| 300 |
+
"kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
|
| 301 |
+
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 302 |
+
}
|
| 303 |
+
if params["kernel"] in ["rbf", "poly", "sigmoid"]:
|
| 304 |
+
params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
|
| 305 |
+
else:
|
| 306 |
+
params["gamma"] = "scale"
|
| 307 |
+
|
| 308 |
+
model, p_tr, p_va = train_svm(Xtr, ytr, Xva, yva, params)
|
| 309 |
+
|
| 310 |
+
else:
|
| 311 |
+
params = {
|
| 312 |
+
"C": trial.suggest_float("C", 1e-3, 1e3, log=True),
|
| 313 |
+
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 314 |
+
"max_iter": trial.suggest_int("max_iter", 2000, 20000),
|
| 315 |
+
}
|
| 316 |
+
model, p_tr, p_va = train_linearsvm_calibrated(Xtr, ytr, Xva, yva, params)
|
| 317 |
+
elif model_name == "svm_gpu":
|
| 318 |
+
params = {
|
| 319 |
+
"C": trial.suggest_float("C", 1e-3, 1e3, log=True),
|
| 320 |
+
"kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
|
| 321 |
+
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 322 |
+
"probability": True,
|
| 323 |
+
"max_iter": trial.suggest_int("max_iter", 200, 5000),
|
| 324 |
+
"tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
|
| 325 |
+
}
|
| 326 |
+
if params["kernel"] in ["rbf", "poly", "sigmoid"]:
|
| 327 |
+
params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
|
| 328 |
+
else:
|
| 329 |
+
params["gamma"] = "scale"
|
| 330 |
+
|
| 331 |
+
model, p_tr, p_va = train_cuml_svc(Xtr, ytr, Xva, yva, params)
|
| 332 |
+
|
| 333 |
+
elif model_name == "enet_gpu":
|
| 334 |
+
params = {
|
| 335 |
+
"C": trial.suggest_float("C", 1e-4, 1e3, log=True),
|
| 336 |
+
"l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
|
| 337 |
+
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 338 |
+
"max_iter": trial.suggest_int("max_iter", 200, 5000),
|
| 339 |
+
"tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
|
| 340 |
+
}
|
| 341 |
+
model, p_tr, p_va = train_cuml_elastic_net(Xtr, ytr, Xva, yva, params)
|
| 342 |
+
else:
|
| 343 |
+
raise ValueError(f"Unknown model_name={model_name}")
|
| 344 |
+
|
| 345 |
+
thr, f1_at_thr = best_f1_threshold(yva, p_va)
|
| 346 |
+
metrics = eval_binary(yva, p_va, thr)
|
| 347 |
+
trial.set_user_attr("threshold", thr)
|
| 348 |
+
trial.set_user_attr("auc", metrics["auc"])
|
| 349 |
+
trial.set_user_attr("ap", metrics["ap"])
|
| 350 |
+
return f1_at_thr
|
| 351 |
+
|
| 352 |
+
return objective
|
| 353 |
+
|
| 354 |
+
# -----------------------------
|
| 355 |
+
# Main
|
| 356 |
+
# -----------------------------
|
| 357 |
+
def run_optuna_and_refit(
|
| 358 |
+
dataset_path: str,
|
| 359 |
+
out_dir: str,
|
| 360 |
+
model_name: str,
|
| 361 |
+
n_trials: int = 200,
|
| 362 |
+
):
|
| 363 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 364 |
+
|
| 365 |
+
data = load_split_data(dataset_path)
|
| 366 |
+
print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
|
| 367 |
+
|
| 368 |
+
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
|
| 369 |
+
study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)
|
| 370 |
+
|
| 371 |
+
trials_df = study.trials_dataframe()
|
| 372 |
+
trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
|
| 373 |
+
|
| 374 |
+
best = study.best_trial
|
| 375 |
+
best_params = dict(best.params)
|
| 376 |
+
best_thr = float(best.user_attrs["threshold"])
|
| 377 |
+
best_auc = float(best.user_attrs["auc"])
|
| 378 |
+
best_ap = float(best.user_attrs["ap"])
|
| 379 |
+
best_f1 = float(best.value)
|
| 380 |
+
|
| 381 |
+
# Refit best model on train
|
| 382 |
+
if model_name == "xgb":
|
| 383 |
+
params = {
|
| 384 |
+
"objective": "binary:logistic",
|
| 385 |
+
"eval_metric": "logloss",
|
| 386 |
+
"lambda": best_params["lambda"],
|
| 387 |
+
"alpha": best_params["alpha"],
|
| 388 |
+
"colsample_bytree": best_params["colsample_bytree"],
|
| 389 |
+
"subsample": best_params["subsample"],
|
| 390 |
+
"learning_rate": best_params["learning_rate"],
|
| 391 |
+
"max_depth": best_params["max_depth"],
|
| 392 |
+
"min_child_weight": best_params["min_child_weight"],
|
| 393 |
+
"gamma": best_params["gamma"],
|
| 394 |
+
"tree_method": "hist",
|
| 395 |
+
"num_boost_round": best_params["num_boost_round"],
|
| 396 |
+
"early_stopping_rounds": best_params["early_stopping_rounds"],
|
| 397 |
+
}
|
| 398 |
+
model, p_tr, p_va = train_xgb(
|
| 399 |
+
data.X_train, data.y_train, data.X_val, data.y_val, params
|
| 400 |
+
)
|
| 401 |
+
model_path = os.path.join(out_dir, "best_model.json")
|
| 402 |
+
model.save_model(model_path)
|
| 403 |
+
|
| 404 |
+
elif model_name == "svm":
|
| 405 |
+
svm_kind = best_params["svm_kind"]
|
| 406 |
+
if svm_kind == "svc":
|
| 407 |
+
model, p_tr, p_va = train_svm(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
|
| 408 |
+
else:
|
| 409 |
+
model, p_tr, p_va = train_linearsvm_calibrated(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
|
| 410 |
+
|
| 411 |
+
model_path = os.path.join(out_dir, "best_model.joblib")
|
| 412 |
+
joblib.dump(model, model_path)
|
| 413 |
+
elif model_name == "svm_gpu":
|
| 414 |
+
model, p_tr, p_va = train_cuml_svc(
|
| 415 |
+
data.X_train, data.y_train, data.X_val, data.y_val, best_params
|
| 416 |
+
)
|
| 417 |
+
model_path = os.path.join(out_dir, "best_model_cuml_svc.joblib")
|
| 418 |
+
joblib.dump(model, model_path)
|
| 419 |
+
|
| 420 |
+
elif model_name == "enet_gpu":
|
| 421 |
+
model, p_tr, p_va = train_cuml_elastic_net(
|
| 422 |
+
data.X_train, data.y_train, data.X_val, data.y_val, best_params
|
| 423 |
+
)
|
| 424 |
+
model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
|
| 425 |
+
joblib.dump(model, model_path)
|
| 426 |
+
else:
|
| 427 |
+
raise ValueError(model_name)
|
| 428 |
+
|
| 429 |
+
# Save predictions CSVs
|
| 430 |
+
save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
|
| 431 |
+
save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)
|
| 432 |
+
|
| 433 |
+
# Plots on val
|
| 434 |
+
plot_curves(out_dir, data.y_val, p_va)
|
| 435 |
+
|
| 436 |
+
summary = [
|
| 437 |
+
"=" * 72,
|
| 438 |
+
f"MODEL: {model_name}",
|
| 439 |
+
f"Best trial: {best.number}",
|
| 440 |
+
f"Best F1 (val @ best-threshold): {best_f1:.4f}",
|
| 441 |
+
f"Val AUC: {best_auc:.4f}",
|
| 442 |
+
f"Val AP: {best_ap:.4f}",
|
| 443 |
+
f"Best threshold (picked on val): {best_thr:.4f}",
|
| 444 |
+
f"Model saved to: {model_path}",
|
| 445 |
+
"Best params:",
|
| 446 |
+
json.dumps(best_params, indent=2),
|
| 447 |
+
"=" * 72,
|
| 448 |
+
]
|
| 449 |
+
with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
|
| 450 |
+
f.write("\n".join(summary))
|
| 451 |
+
print("\n".join(summary))
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
if __name__ == "__main__":
|
| 455 |
+
import argparse
|
| 456 |
+
parser = argparse.ArgumentParser()
|
| 457 |
+
parser.add_argument("--dataset_path", type=str, required=True)
|
| 458 |
+
parser.add_argument("--out_dir", type=str, required=True)
|
| 459 |
+
parser.add_argument("--model", type=str, choices=["xgb", "svm_gpu", "enet_gpu"], required=True)
|
| 460 |
+
parser.add_argument("--n_trials", type=int, default=200)
|
| 461 |
+
args = parser.parse_args()
|
| 462 |
+
|
| 463 |
+
run_optuna_and_refit(
|
| 464 |
+
dataset_path=args.dataset_path,
|
| 465 |
+
out_dir=args.out_dir,
|
| 466 |
+
model_name=args.model,
|
| 467 |
+
n_trials=args.n_trials,
|
| 468 |
+
)
|
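As a usage note: a minimal sketch of driving this classifier script programmatically rather than via its CLI, assuming the non-checkpoint copy of this file is importable as train_ml, and that the dataset and output paths below are placeholders (not paths from this commit) pointing at an embeddings DatasetDict saved with datasets.save_to_disk:

# Hypothetical driver; "train_ml", both paths, and the trial count are
# illustrative assumptions, not values taken from this repository.
from train_ml import run_optuna_and_refit

run_optuna_and_refit(
    dataset_path="./embeddings/nonfouling_pooled",  # DatasetDict with "train"/"val"
    out_dir="./runs/nonfouling_xgb",
    model_name="xgb",  # the function also accepts "svm", "svm_gpu", "enet_gpu"
    n_trials=50,
)

Note that the argparse entry point only exposes "xgb", "svm_gpu", and "enet_gpu"; the CPU "svm" branch handled inside make_objective and the refit is reachable only through a direct call like this.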
training_classifiers/.ipynb_checkpoints/train_ml_regression-checkpoint.py ADDED
@@ -0,0 +1,410 @@
import os
import json
import joblib
import optuna
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, Any, Tuple, Optional
from datasets import load_from_disk, DatasetDict
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.svm import SVR
import xgboost as xgb
from lightning.pytorch import seed_everything
import cupy as cp
from cuml.linear_model import ElasticNet as cuElasticNet
from scipy.stats import spearmanr
seed_everything(1986)


# -----------------------------
# GPU/CPU helpers
# -----------------------------
def to_gpu(X: np.ndarray):
    if isinstance(X, cp.ndarray):
        return X
    return cp.asarray(X, dtype=cp.float32)

def to_cpu(x):
    if isinstance(x, cp.ndarray):
        return cp.asnumpy(x)
    return np.asarray(x)


# -----------------------------
# Data loading
# -----------------------------
@dataclass
class SplitData:
    X_train: np.ndarray
    y_train: np.ndarray
    seq_train: Optional[np.ndarray]
    X_val: np.ndarray
    y_val: np.ndarray
    seq_val: Optional[np.ndarray]

def _stack_embeddings(col) -> np.ndarray:
    arr = np.asarray(col, dtype=np.float32)
    if arr.ndim != 2:
        arr = np.stack(col).astype(np.float32)
    return arr

def load_split_data(dataset_path: str) -> SplitData:
    ds = load_from_disk(dataset_path)

    if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
        train_ds, val_ds = ds["train"], ds["val"]
    else:
        if "split" not in ds.column_names:
            raise ValueError("Dataset must be a DatasetDict(train/val) or have a 'split' column.")
        train_ds = ds.filter(lambda x: x["split"] == "train")
        val_ds = ds.filter(lambda x: x["split"] == "val")

    for required in ["embedding", "label"]:
        if required not in train_ds.column_names:
            raise ValueError(f"Missing column '{required}' in train split.")
        if required not in val_ds.column_names:
            raise ValueError(f"Missing column '{required}' in val split.")

    X_train = _stack_embeddings(train_ds["embedding"]).astype(np.float32)
    X_val = _stack_embeddings(val_ds["embedding"]).astype(np.float32)

    y_train = np.asarray(train_ds["label"], dtype=np.float32)
    y_val = np.asarray(val_ds["label"], dtype=np.float32)

    seq_train = None
    seq_val = None
    if "sequence" in train_ds.column_names:
        seq_train = np.asarray(train_ds["sequence"])
    if "sequence" in val_ds.column_names:
        seq_val = np.asarray(val_ds["sequence"])

    return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)


# -----------------------------
# Metrics
# -----------------------------
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)

def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    # RMSE
    try:
        from sklearn.metrics import root_mean_squared_error
        rmse = root_mean_squared_error(y_true, y_pred)
    except Exception:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))

    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))
    rho = float(safe_spearmanr(y_true, y_pred))
    return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}


# -----------------------------
# Model
# -----------------------------
def train_xgb_reg(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    num_boost_round = int(params.pop("num_boost_round"))
    early_stopping_rounds = int(params.pop("early_stopping_rounds"))

    booster = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dval, "val")],
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=False,
    )

    p_train = booster.predict(dtrain)
    p_val = booster.predict(dval)
    return booster, p_train, p_val


def train_cuml_elasticnet_reg(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
):
    Xtr = to_gpu(X_train)
    Xva = to_gpu(X_val)
    ytr = to_gpu(y_train).astype(cp.float32)

    model = cuElasticNet(
        alpha=float(params["alpha"]),
        l1_ratio=float(params["l1_ratio"]),
        fit_intercept=True,
        max_iter=int(params.get("max_iter", 5000)),
        tol=float(params.get("tol", 1e-4)),
        selection=params.get("selection", "cyclic"),
    )
    model.fit(Xtr, ytr)

    p_train = to_cpu(model.predict(Xtr))
    p_val = to_cpu(model.predict(Xva))
    return model, p_train, p_val


def train_svr_reg(
    X_train, y_train, X_val, y_val, params: Dict[str, Any]
):
    model = SVR(
        C=float(params["C"]),
        epsilon=float(params["epsilon"]),
        kernel=params["kernel"],
        gamma=params.get("gamma", "scale"),
    )
    model.fit(X_train, y_train)
    p_train = model.predict(X_train)
    p_val = model.predict(X_val)
    return model, p_train, p_val


# -----------------------------
# Saving + plots
# -----------------------------
def save_predictions_csv(
    out_dir: str,
    split_name: str,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    sequences: Optional[np.ndarray] = None,
):
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame({
        "y_true": y_true.astype(float),
        "y_pred": y_pred.astype(float),
        "residual": (y_true - y_pred).astype(float),
    })
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)

def plot_regression_diagnostics(out_dir: str, y_true: np.ndarray, y_pred: np.ndarray):
    os.makedirs(out_dir, exist_ok=True)

    plt.figure()
    plt.scatter(y_true, y_pred, s=8, alpha=0.5)
    plt.xlabel("y_true")
    plt.ylabel("y_pred")
    plt.title("Predicted vs True")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
    plt.close()

    resid = y_true - y_pred
    plt.figure()
    plt.hist(resid, bins=50)
    plt.xlabel("residual (y_true - y_pred)")
    plt.ylabel("count")
    plt.title("Residual Histogram")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "residual_hist.png"))
    plt.close()

    plt.figure()
    plt.scatter(y_pred, resid, s=8, alpha=0.5)
    plt.xlabel("y_pred")
    plt.ylabel("residual")
    plt.title("Residuals vs Prediction")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
    plt.close()


# -----------------------------
# Optuna objective (OPTIMIZE SPEARMAN RHO)
# -----------------------------
def make_objective(model_name: str, data: SplitData):
    Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val

    def objective(trial: optuna.Trial) -> float:
        if model_name == "xgb_reg":
            params = {
                "objective": "reg:squarederror",
                "eval_metric": "rmse",
                "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
                "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
                "gamma": trial.suggest_float("gamma", 0.0, 10.0),
                "max_depth": trial.suggest_int("max_depth", 2, 16),
                "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 500.0, log=True),
                "subsample": trial.suggest_float("subsample", 0.5, 1.0),
                "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
                "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
                "tree_method": "hist",
                "device": "cuda",
            }
            params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 2000)
            params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)

            model, p_tr, p_va = train_xgb_reg(Xtr, ytr, Xva, yva, params.copy())

        elif model_name == "enet_gpu":
            params = {
                "alpha": trial.suggest_float("alpha", 1e-8, 10.0, log=True),
                "l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
                "max_iter": trial.suggest_int("max_iter", 1000, 20000),
                "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
                "selection": trial.suggest_categorical("selection", ["cyclic", "random"]),
            }
            model, p_tr, p_va = train_cuml_elasticnet_reg(Xtr, ytr, Xva, yva, params)

        elif model_name == "svr":
            params = {
                "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
                "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
                "epsilon": trial.suggest_float("epsilon", 1e-4, 1.0, log=True),
            }
            if params["kernel"] in ["rbf", "poly", "sigmoid"]:
                params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
            else:
                params["gamma"] = "scale"

            model, p_tr, p_va = train_svr_reg(Xtr, ytr, Xva, yva, params)

        else:
            raise ValueError(f"Unknown model_name={model_name}")

        metrics = eval_regression(yva, p_va)
        trial.set_user_attr("spearman_rho", metrics["spearman_rho"])
        trial.set_user_attr("rmse", metrics["rmse"])
        trial.set_user_attr("mae", metrics["mae"])
        trial.set_user_attr("r2", metrics["r2"])

        # OPTUNA OBJECTIVE = maximize Spearman rho
        return metrics["spearman_rho"]

    return objective


# -----------------------------
# Main
# -----------------------------
def run_optuna_and_refit(
    dataset_path: str,
    out_dir: str,
    model_name: str,
    n_trials: int = 200,
    standardize_X: bool = True,
):
    os.makedirs(out_dir, exist_ok=True)

    data = load_split_data(dataset_path)
    print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")

    # Standardize features (SVR + ElasticNet)
    if standardize_X:
        scaler = StandardScaler()
        data.X_train = scaler.fit_transform(data.X_train).astype(np.float32)
        data.X_val = scaler.transform(data.X_val).astype(np.float32)
        joblib.dump(scaler, os.path.join(out_dir, "scaler.joblib"))
        print("[Preprocess] Saved StandardScaler -> scaler.joblib")

    study = optuna.create_study(
        direction="maximize",
        pruner=optuna.pruners.MedianPruner()
    )
    study.optimize(make_objective(model_name, data), n_trials=n_trials)

    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)

    best_rho = float(best.user_attrs.get("spearman_rho", best.value))
    best_rmse = float(best.user_attrs.get("rmse", np.nan))
    best_mae = float(best.user_attrs.get("mae", np.nan))
    best_r2 = float(best.user_attrs.get("r2", np.nan))

    # Refit best model on train
    if model_name == "xgb_reg":
        params = {
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "lambda": best_params["lambda"],
            "alpha": best_params["alpha"],
            "gamma": best_params["gamma"],
            "max_depth": best_params["max_depth"],
            "min_child_weight": best_params["min_child_weight"],
            "subsample": best_params["subsample"],
            "colsample_bytree": best_params["colsample_bytree"],
            "learning_rate": best_params["learning_rate"],
            "tree_method": "hist",
            "device": "cuda",
            "num_boost_round": best_params["num_boost_round"],
            "early_stopping_rounds": best_params["early_stopping_rounds"],
        }
        model, p_tr, p_va = train_xgb_reg(
            data.X_train, data.y_train, data.X_val, data.y_val, params
        )
        model_path = os.path.join(out_dir, "best_model.json")
        model.save_model(model_path)

    elif model_name == "enet_gpu":
        model, p_tr, p_va = train_cuml_elasticnet_reg(
            data.X_train, data.y_train, data.X_val, data.y_val, best_params
        )
        model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
        joblib.dump(model, model_path)

    elif model_name == "svr":
        model, p_tr, p_va = train_svr_reg(
            data.X_train, data.y_train, data.X_val, data.y_val, best_params
        )
        model_path = os.path.join(out_dir, "best_model_svr.joblib")
        joblib.dump(model, model_path)

    else:
        raise ValueError(model_name)

    save_predictions_csv(out_dir, "train", data.y_train, p_tr, data.seq_train)
    save_predictions_csv(out_dir, "val", data.y_val, p_va, data.seq_val)

    plot_regression_diagnostics(out_dir, data.y_val, p_va)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        f"Best trial: {best.number}",
        f"Val Spearman rho (objective): {best_rho:.6f}",
        f"Val RMSE: {best_rmse:.6f}",
        f"Val MAE: {best_mae:.6f}",
        f"Val R2: {best_r2:.6f}",
        f"Model saved to: {model_path}",
        "Best params:",
        json.dumps(best_params, indent=2),
        "=" * 72,
    ]
    with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["xgb_reg", "enet_gpu", "svr"], required=True)
    parser.add_argument("--n_trials", type=int, default=200)
    parser.add_argument("--no_standardize", action="store_true", help="Disable StandardScaler on X")
    args = parser.parse_args()

    run_optuna_and_refit(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        model_name=args.model,
        n_trials=args.n_trials,
        standardize_X=(not args.no_standardize),
    )
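Since the objective above returns Spearman rho rather than RMSE, a small self-contained check of why that choice is scale-robust may help; the numbers are fabricated for illustration only:

import numpy as np
from scipy.stats import spearmanr

y_true = np.array([0.10, 0.40, 0.35, 0.80])
y_pred = np.array([0.05, 0.50, 0.30, 0.90])  # same ordering as y_true

# Spearman rho depends only on ranks, so any strictly monotone
# transform of the predictions leaves the objective unchanged,
# while RMSE/MAE would move. This is why eval_regression's
# spearman_rho is a sensible Optuna target when only the ranking
# of candidates matters downstream.
print(spearmanr(y_true, y_pred).correlation)          # 1.0
print(spearmanr(y_true, np.log(y_pred)).correlation)  # still 1.0
print(spearmanr(y_true, -y_pred).correlation)         # -1.0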
training_classifiers/.ipynb_checkpoints/train_nn-checkpoint.py ADDED
@@ -0,0 +1,426 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
import torch.nn as nn
import optuna
import os
from typing import Dict, Any, Tuple, Optional
import matplotlib.pyplot as plt
from sklearn.metrics import (
    f1_score, roc_auc_score, average_precision_score,
    precision_recall_curve, roc_curve,
)
import json
import joblib
import pandas as pd
import time

def infer_in_dim_from_unpooled_ds(ds) -> int:
    ex = ds[0]
    # ex["embedding"] is (L, H) list/array
    return int(len(ex["embedding"][0]))

def load_split(dataset_path):
    ds = load_from_disk(dataset_path)

    if isinstance(ds, DatasetDict):
        return ds["train"], ds["val"]

    raise ValueError("Expected DatasetDict with 'train' and 'val' splits")

def collate_unpooled(batch):
    # batch: list of dicts
    lengths = [int(x["length"]) for x in batch]
    Lmax = max(lengths)
    H = len(batch[0]["embedding"][0])  # 1280

    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        emb = torch.tensor(x["embedding"], dtype=torch.float32)  # (L, H)
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            m = torch.tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            M[i, :L] = True

    return X, M, y

# ======================== Helper functions =========================================
def save_predictions_csv(
    out_dir: str,
    split_name: str,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
    sequences: Optional[np.ndarray] = None,
):
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame({
        "y_true": y_true.astype(int),
        "y_prob": y_prob.astype(float),
        "y_pred": (y_prob >= threshold).astype(int),
    })
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)


def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
    os.makedirs(out_dir, exist_ok=True)

    # PR
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    plt.figure()
    plt.plot(recall, precision)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pr_curve.png"))
    plt.close()

    # ROC
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "roc_curve.png"))
    plt.close()

# ======================== Shared OPTUNA training scheme =========================================
def best_f1_threshold(y_true, y_prob):
    p, r, thr = precision_recall_curve(y_true, y_prob)
    f1s = (2*p[:-1]*r[:-1])/(p[:-1]+r[:-1]+1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])

@torch.no_grad()
def eval_probs(model, loader, device):
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        logits = model(X, M)
        prob = torch.sigmoid(logits).detach().cpu().numpy()
        ys.append(y.numpy())
        ps.append(prob)
    return np.concatenate(ys), np.concatenate(ps)

def train_one_epoch(model, loader, optim, criterion, device):
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        logits = model(X, M)
        loss = criterion(logits, y)
        loss.backward()
        optim.step()

# ======================== MLP =========================================
# Still need mean pooling along lengths
class MaskedMeanPool(nn.Module):
    def forward(self, X, M):  # X: (B,L,H), M: (B,L)
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom  # (B,H)

class MLPClassifier(nn.Module):
    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )
    def forward(self, X, M):
        z = self.pool(X, M)
        return self.net(z).squeeze(-1)  # logits

# ======================== CNN =========================================
# Treat 1280 dimensions as channels
class CNNClassifier(nn.Module):
    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks = []
        ch = in_ch
        for _ in range(layers):
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # X: (B,L,H) -> (B,H,L)
        Xc = X.transpose(1, 2)
        Y = self.conv(Xc).transpose(1, 2)  # (B,L,C)

        # masked mean pool over L
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Y * Mf).sum(dim=1) / denom  # (B,C)
        return self.head(pooled).squeeze(-1)

# ========================== Transformer ====================================
class TransformerClassifier(nn.Module):
    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # src_key_padding_mask: True = pad positions
        pad_mask = ~M
        Z = self.proj(X)  # (B,L,d)
        Z = self.enc(Z, src_key_padding_mask=pad_mask)  # (B,L,d)

        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Z * Mf).sum(dim=1) / denom
        return self.head(pooled).squeeze(-1)

# ========================== OPTUNA ====================================

def objective_nn(trial, model_name, train_ds, val_ds, device="cuda:0"):
    # hyperparams shared
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-8, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=4, pin_memory=True)

    in_dim = infer_in_dim_from_unpooled_ds(train_ds)

    if model_name == "mlp":
        hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
        model = MLPClassifier(in_dim=in_dim, hidden=hidden, dropout=dropout)
    elif model_name == "cnn":
        c = trial.suggest_categorical("channels", [128, 256, 512])
        k = trial.suggest_categorical("kernel", [3, 5, 7])
        layers = trial.suggest_int("layers", 1, 4)
        model = CNNClassifier(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
    elif model_name == "transformer":
        d = trial.suggest_categorical("d_model", [128, 256, 384])
        nhead = trial.suggest_categorical("nhead", [4, 8])
        layers = trial.suggest_int("layers", 1, 4)
        ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
        model = TransformerClassifier(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
    else:
        raise ValueError(model_name)

    model = model.to(device)

    # class imbalance handling
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos = ytr.sum()
    neg = len(ytr) - pos
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_f1 = -1.0
    patience = 8
    bad = 0

    for epoch in range(1, 51):
        train_one_epoch(model, train_loader, optim, criterion, device)

        y_true, y_prob = eval_probs(model, val_loader, device)
        auc = roc_auc_score(y_true, y_prob)

        thr, f1 = best_f1_threshold(y_true, y_prob)

        trial.set_user_attr("val_auc", float(auc))
        trial.set_user_attr("val_f1", float(f1))
        trial.set_user_attr("val_thr", float(thr))

        # prune
        trial.report(f1, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    return best_f1

def run_optuna_and_refit_nn(dataset_path: str, out_dir: str, model_name: str, n_trials: int = 50, device="cuda:0"):
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda trial: objective_nn(trial, model_name, train_ds, val_ds, device=device), n_trials=n_trials)

    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)
    best_f1_optuna = float(best.value)
    best_auc_optuna = float(best.user_attrs.get("val_auc", np.nan))
    best_thr = float(best.user_attrs.get("val_thr", 0.5))

    in_dim = infer_in_dim_from_unpooled_ds(train_ds)

    # --- Refit best model ---
    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=4, pin_memory=True)

    # Rebuild
    dropout = float(best_params.get("dropout", 0.1))
    if model_name == "mlp":
        model = MLPClassifier(
            in_dim=in_dim,
            hidden=int(best_params["hidden"]),
            dropout=dropout,
        )

    elif model_name == "cnn":
        model = CNNClassifier(
            in_ch=in_dim,
            c=int(best_params["channels"]),
            k=int(best_params["kernel"]),
            layers=int(best_params["layers"]),
            dropout=dropout,
        )

    elif model_name == "transformer":
        model = TransformerClassifier(
            in_dim=in_dim,
            d_model=int(best_params["d_model"]),
            nhead=int(best_params["nhead"]),
            layers=int(best_params["layers"]),
            ff=int(best_params["ff"]),
            dropout=dropout,
        )
    else:
        raise ValueError(model_name)

    model = model.to(device)

    # loss + optimizer
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos = ytr.sum()
    neg = len(ytr) - pos
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    lr = float(best_params["lr"])
    wd = float(best_params["weight_decay"])
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    # train longer with early stopping on F1 (threshold picked per epoch)
    best_f1_seen, bad, patience = -1.0, 0, 12
    best_state = None
    best_thr_seen = 0.5
    best_auc_seen = -1.0

    for epoch in range(1, 151):
        train_one_epoch(model, train_loader, optim, criterion, device)

        y_true, y_prob = eval_probs(model, val_loader, device)
        auc = roc_auc_score(y_true, y_prob)
        thr, f1 = best_f1_threshold(y_true, y_prob)

        if f1 > best_f1_seen + 1e-4:
            best_f1_seen = f1
            best_thr_seen = thr
            best_auc_seen = auc
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # final preds + threshold picked on val
    y_true_val, y_prob_val = eval_probs(model, val_loader, device)
    best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)

    # save model
    model_path = os.path.join(out_dir, "best_model.pt")
    torch.save({"state_dict": model.state_dict(), "best_params": best_params}, model_path)

    # train preds
    y_true_tr, y_prob_tr = eval_probs(model, DataLoader(train_ds, batch_size=64, shuffle=False,
                                      collate_fn=collate_unpooled, num_workers=4, pin_memory=True), device)

    save_predictions_csv(out_dir, "train", y_true_tr, y_prob_tr, best_thr_final,
                         sequences=np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None)
    save_predictions_csv(out_dir, "val", y_true_val, y_prob_val, best_thr_final,
                         sequences=np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None)

    plot_curves(out_dir, y_true_val, y_prob_val)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",

        # Optuna results (objective = F1)
        f"Best Optuna F1 (objective): {best_f1_optuna:.4f}",
        f"Best Optuna AUC (val, recorded): {best_auc_optuna:.4f}",
        f"Best Optuna threshold (val): {best_thr:.4f}",

        # Refit results
        f"Refit best AUC (val): {best_auc_seen:.4f}",
        f"Refit best F1@thr (val): {best_f1_final:.4f} at thr={best_thr_final:.4f}",

        "Best params:",
        json.dumps(best_params, indent=2),
        f"Saved model: {model_path}",
        "=" * 72,
    ]

    with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
    parser.add_argument("--n_trials", type=int, default=50)
    args = parser.parse_args()

    if args.model in ["mlp", "cnn", "transformer"]:
        run_optuna_and_refit_nn(args.dataset_path, args.out_dir, args.model, args.n_trials, device="cuda:0")
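To make the padding/masking contract of collate_unpooled concrete, here is a small fabricated batch (hidden size 4 instead of the real 1280); the import path is a hypothetical assumption that the definitions above live in a module importable as train_nn:

import torch
from train_nn import collate_unpooled, MaskedMeanPool  # hypothetical module name

# Fabricated mini-batch: two sequences of lengths 2 and 3.
batch = [
    {"embedding": [[0.1] * 4] * 2, "length": 2, "label": 1},  # L=2
    {"embedding": [[0.2] * 4] * 3, "length": 3, "label": 0},  # L=3
]
X, M, y = collate_unpooled(batch)
print(X.shape)  # torch.Size([2, 3, 4]) -- zero-padded to the longest sequence
print(M)        # [[True, True, False], [True, True, True]]

# MaskedMeanPool divides by the count of valid positions, so the
# zero-padded row never dilutes the pooled embedding.
print(MaskedMeanPool()(X, M))  # row 0 stays at 0.1, row 1 at 0.2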
training_classifiers/.ipynb_checkpoints/train_nn_regression-checkpoint.py ADDED
@@ -0,0 +1,420 @@
| 1 |
+
import os, json, time
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from datasets import load_from_disk, DatasetDict
|
| 10 |
+
import optuna
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Dict, Any, Tuple, Optional
|
| 13 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
| 14 |
+
from scipy.stats import spearmanr
|
| 15 |
+
from torch.cuda.amp import autocast
|
| 16 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 17 |
+
scaler = GradScaler(enabled=torch.cuda.is_available())
|
| 18 |
+
from lightning.pytorch import seed_everything
|
| 19 |
+
seed_everything(1986)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_split(dataset_path):
|
| 23 |
+
ds = load_from_disk(dataset_path)
|
| 24 |
+
if isinstance(ds, DatasetDict):
|
| 25 |
+
return ds["train"], ds["val"]
|
| 26 |
+
raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
|
| 27 |
+
|
| 28 |
+
def collate_unpooled_reg(batch):
|
| 29 |
+
lengths = [int(x["length"]) for x in batch]
|
| 30 |
+
Lmax = max(lengths)
|
| 31 |
+
H = len(batch[0]["embedding"][0])
|
| 32 |
+
|
| 33 |
+
X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
|
| 34 |
+
M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
|
| 35 |
+
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
|
| 36 |
+
|
| 37 |
+
for i, x in enumerate(batch):
|
| 38 |
+
emb = torch.tensor(x["embedding"], dtype=torch.float32) # (L,H)
|
| 39 |
+
L = emb.shape[0]
|
| 40 |
+
X[i, :L] = emb
|
| 41 |
+
if "attention_mask" in x:
|
| 42 |
+
m = torch.tensor(x["attention_mask"], dtype=torch.bool)
|
| 43 |
+
M[i, :L] = m[:L]
|
| 44 |
+
else:
|
| 45 |
+
M[i, :L] = True
|
| 46 |
+
return X, M, y
|

def infer_in_dim(ds) -> int:
    ex = ds[0]
    return int(len(ex["embedding"][0]))

# ============================
# Metrics
# ============================
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)

def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    # ---- RMSE ----
    try:
        from sklearn.metrics import root_mean_squared_error
        rmse = root_mean_squared_error(y_true, y_pred)
    except Exception:
        mse = mean_squared_error(y_true, y_pred)
        rmse = float(np.sqrt(mse))

    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))
    rho = float(safe_spearmanr(y_true, y_pred))
    return {"rmse": float(rmse), "mae": mae, "r2": r2, "spearman_rho": rho}
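# --- illustrative sketch (editor addition, not in the original script):
# eval_regression packs all four metrics into one dict, e.g.
#   y_true = np.array([1.0, 2.0, 3.0, 4.0])
#   y_pred = np.array([1.1, 1.9, 3.2, 3.8])
#   eval_regression(y_true, y_pred)
#   # -> {"rmse": ~0.158, "mae": 0.15, "r2": 0.98, "spearman_rho": 1.0}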


# ============================
# Models
# ============================
class MaskedMeanPool(nn.Module):
    def forward(self, X, M):
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

class MLPRegressor(nn.Module):
    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )
    def forward(self, X, M):
        z = self.pool(X, M)
        return self.net(z).squeeze(-1)  # y_pred

class CNNRegressor(nn.Module):
    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks = []
        ch = in_ch
        for _ in range(layers):
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        Xc = X.transpose(1, 2)             # (B,H,L)
        Y = self.conv(Xc).transpose(1, 2)  # (B,L,C)
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Y * Mf).sum(dim=1) / denom  # (B,C)
        return self.head(pooled).squeeze(-1)

class TransformerRegressor(nn.Module):
    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        pad_mask = ~M
        Z = self.proj(X)
        Z = self.enc(Z, src_key_padding_mask=pad_mask)
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Z * Mf).sum(dim=1) / denom
        return self.head(pooled).squeeze(-1)

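# --- illustrative sketch (editor addition, not in the original script):
# all three heads share the same (X, M) -> (B,) interface, e.g.
#   X = torch.randn(4, 12, 640)            # (batch, max_len, embedding_dim)
#   M = torch.ones(4, 12, dtype=torch.bool)
#   for m in (MLPRegressor(640), CNNRegressor(640), TransformerRegressor(640)):
#       assert m(X, M).shape == (4,)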
# ============================
# Train / eval
# ============================
@torch.no_grad()
def eval_preds(model, loader, device):
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        pred = model(X, M).detach().cpu().numpy()
        ys.append(y.numpy())
        ps.append(pred)
    return np.concatenate(ys), np.concatenate(ps)

def train_one_epoch_reg(model, loader, optim, criterion, device):
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            pred = model(X, M)
            loss = criterion(pred, y)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

# ============================
# Saving + plots
# ============================
def save_predictions_csv(out_dir, split_name, y_true, y_pred, sequences=None):
    os.makedirs(out_dir, exist_ok=True)
    df = pd.DataFrame({
        "y_true": y_true.astype(float),
        "y_pred": y_pred.astype(float),
        "residual": (y_true - y_pred).astype(float),
    })
    if sequences is not None:
        df.insert(0, "sequence", sequences)
    df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)

def plot_regression_diagnostics(out_dir, y_true, y_pred):
    os.makedirs(out_dir, exist_ok=True)

    plt.figure()
    plt.scatter(y_true, y_pred, s=8, alpha=0.5)
    plt.xlabel("y_true"); plt.ylabel("y_pred")
    plt.title("Predicted vs True")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
    plt.close()

    resid = y_true - y_pred
    plt.figure()
    plt.hist(resid, bins=50)
    plt.xlabel("residual (y_true - y_pred)"); plt.ylabel("count")
    plt.title("Residual Histogram")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "residual_hist.png"))
    plt.close()

    plt.figure()
    plt.scatter(y_pred, resid, s=8, alpha=0.5)
    plt.xlabel("y_pred"); plt.ylabel("residual")
    plt.title("Residuals vs Prediction")
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
    plt.close()

# ============================
# Optuna objective
# ============================
def score_from_metrics(metrics: Dict[str, float], objective: str) -> float:
    if objective == "spearman":
        return metrics["spearman_rho"]
    if objective == "r2":
        return metrics["r2"]
    if objective == "neg_rmse":
        return -metrics["rmse"]
    raise ValueError(f"Unknown objective={objective}")

def objective_nn_reg(trial, model_name, train_ds, val_ds, device="cuda:0", objective="spearman"):
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    in_dim = infer_in_dim(train_ds)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)

    if model_name == "mlp":
        hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
        model = MLPRegressor(in_dim=in_dim, hidden=hidden, dropout=dropout)
    elif model_name == "cnn":
        c = trial.suggest_categorical("channels", [128, 256, 512])
        k = trial.suggest_categorical("kernel", [3, 5, 7])
        layers = trial.suggest_int("layers", 1, 4)
        model = CNNRegressor(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
    elif model_name == "transformer":
        d = trial.suggest_categorical("d_model", [128, 256, 384])
        nhead = trial.suggest_categorical("nhead", [4, 8])
        layers = trial.suggest_int("layers", 1, 4)
        ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
        model = TransformerRegressor(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
    else:
        raise ValueError(model_name)

    model = model.to(device)

    loss_name = trial.suggest_categorical("loss", ["mse", "huber"])
    if loss_name == "mse":
        criterion = nn.MSELoss()
    else:
        delta = trial.suggest_float("huber_delta", 0.5, 5.0, log=True)
        criterion = nn.HuberLoss(delta=delta)

    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_score = -1e18
    patience = 10
    bad = 0

    for epoch in range(1, 61):
        train_one_epoch_reg(model, train_loader, optim, criterion, device)

        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        # log attrs
        for k, v in metrics.items():
            trial.set_user_attr(f"val_{k}", float(v))

        trial.report(score, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if score > best_score + 1e-6:
            best_score = score
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    return float(best_score)
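# --- illustrative sketch (editor addition, not in the original script):
# score_from_metrics maps every objective onto a maximization problem, e.g.
#   score_from_metrics({"rmse": 1.25, "mae": 0.9, "r2": 0.4, "spearman_rho": 0.6}, "neg_rmse")
#   # -> -1.25  (so Optuna's direction="maximize" also minimizes RMSE)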

# ============================
# Main runner
# ============================
def run_optuna_and_refit_nn_reg(dataset_path, out_dir, model_name, n_trials=80, device="cuda:0",
                                objective="spearman"):
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda t: objective_nn_reg(t, model_name, train_ds, val_ds, device=device, objective=objective),
                   n_trials=n_trials)

    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)

    best = study.best_trial
    best_params = dict(best.params)

    # rebuild model from best params
    in_dim = infer_in_dim(train_ds)
    dropout = float(best_params.get("dropout", 0.1))
    if model_name == "mlp":
        model = MLPRegressor(in_dim=in_dim, hidden=int(best_params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        model = CNNRegressor(in_ch=in_dim, c=int(best_params["channels"]),
                             k=int(best_params["kernel"]), layers=int(best_params["layers"]),
                             dropout=dropout)
    elif model_name == "transformer":
        model = TransformerRegressor(in_dim=in_dim, d_model=int(best_params["d_model"]),
                                     nhead=int(best_params["nhead"]), layers=int(best_params["layers"]),
                                     ff=int(best_params["ff"]), dropout=dropout)
    else:
        raise ValueError(model_name)

    model = model.to(device)

    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)

    # loss
    if best_params.get("loss", "mse") == "mse":
        criterion = nn.MSELoss()
    else:
        criterion = nn.HuberLoss(delta=float(best_params["huber_delta"]))

    optim = torch.optim.AdamW(model.parameters(), lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    # refit longer with early stopping on the SAME objective
    best_score, bad, patience = -1e18, 0, 15
    best_state = None

    for epoch in range(1, 201):
        train_one_epoch_reg(model, train_loader, optim, criterion, device)

        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        if score > best_score + 1e-6:
            best_score = score
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_metrics = metrics
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # preds
    y_true_tr, y_pred_tr = eval_preds(model, DataLoader(train_ds, batch_size=64, shuffle=False,
                                      collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True), device)
    y_true_va, y_pred_va = eval_preds(model, val_loader, device)

    seq_train = np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None
    seq_val = np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None
    save_predictions_csv(out_dir, "train", y_true_tr, y_pred_tr, seq_train)
    save_predictions_csv(out_dir, "val", y_true_va, y_pred_va, seq_val)
    plot_regression_diagnostics(out_dir, y_true_va, y_pred_va)

    # save model
    model_path = os.path.join(out_dir, "best_model.pt")
    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "in_dim": in_dim}, model_path)

    summary = [
        "=" * 72,
        f"MODEL: {model_name}",
        f"OPTUNA objective: {objective} (direction=maximize)",
        f"Best trial: {best.number}",
        "Best val metrics:",
        json.dumps({k: float(v) for k, v in best_metrics.items()}, indent=2),
        f"Saved model: {model_path}",
        "Best params:",
        json.dumps(best_params, indent=2),
        "=" * 72,
    ]
    with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
        f.write("\n".join(summary))
    print("\n".join(summary))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
    parser.add_argument("--n_trials", type=int, default=80)
    parser.add_argument("--objective", type=str, default="spearman",
                        choices=["spearman", "neg_rmse", "r2"])
    parser.add_argument("--device", type=str, default="cuda:0")
    args = parser.parse_args()

    run_optuna_and_refit_nn_reg(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        model_name=args.model,
        n_trials=args.n_trials,
        device=args.device,
        objective=args.objective,
    )
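# --- illustrative usage (editor addition, not in the original script; the
# dataset path and output directory below are placeholders):
#   run_optuna_and_refit_nn_reg(
#       dataset_path="path/to/hf_dataset",   # DatasetDict with "train"/"val"
#       out_dir="runs/mlp_spearman",
#       model_name="mlp",
#       n_trials=20,
#       device="cuda:0" if torch.cuda.is_available() else "cpu",
#       objective="spearman",
#   )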
training_data_cleaned/data_split.ipynb → training_classifiers/binding_affinity/val_smiles_pooled.csv
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5410a45a7b65def6cfb94c167b07537abd33b5aac4ecdffe162b7ce4e9bc3d19
+size 36525

training_data_cleaned/nf_smiles_train.csv → training_classifiers/binding_affinity/val_smiles_unpooled.csv
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:cdf71fbb3e7b3b8e8dbfe4ed45b32a2da0049df851f09ee32564825f626cb86c
+size 37187

training_data_cleaned/smiles_data_split.ipynb → training_classifiers/binding_affinity/val_wt_pooled.csv
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:b194e7b2b97258320323021b3ffe6143133070212a0215ade22fa91b87a3a861
+size 33224

training_data_cleaned/nf_smiles_val.csv → training_classifiers/binding_affinity/val_wt_unpooled.csv
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:051325790047e749fbf1daf7bf25a08178297b0c37acaf9439816d09f2b6c1e3
+size 33826

training_classifiers/binding_affinity/wt_smiles_pooled/best_model.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12f956a7bf04ed602c11fd275377afa73f3f0af1982dbe06c607d8ada304b01c
+size 21617397

training_classifiers/binding_affinity/wt_smiles_pooled/best_params.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "lr": 0.00011987622631192274,
+  "weight_decay": 5.279397670067118e-05,
+  "dropout": 0.06773313718640918,
+  "hidden_dim": 256,
+  "n_heads": 8,
+  "n_layers": 3,
+  "cls_weight": 0.29331613012593555,
+  "batch_size": 16
+}

training_classifiers/binding_affinity/wt_smiles_pooled/optuna_trials.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:23161560edf5ad2069302afa1d387819dcfb7010c6ff0437c61ec09a8aa0e0f0
+size 40599

training_classifiers/binding_affinity/wt_smiles_unpooled/.ipynb_checkpoints/best_params-checkpoint.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "lr": 6.714904102732621e-05,
+  "weight_decay": 6.94348785472601e-08,
+  "dropout": 0.20599610484012826,
+  "hidden_dim": 768,
+  "n_heads": 4,
+  "n_layers": 3,
+  "cls_weight": 0.26109289573917854,
+  "batch_size": 16
+}

training_classifiers/binding_affinity/wt_smiles_unpooled/best_model.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d7ae3d2190b034352a65bda1bce86aa5a96ce3daf74cf10a166f8d9e9af51f0
+size 181183221

training_classifiers/binding_affinity/wt_smiles_unpooled/best_params.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "lr": 6.714904102732621e-05,
+  "weight_decay": 6.94348785472601e-08,
+  "dropout": 0.20599610484012826,
+  "hidden_dim": 768,
+  "n_heads": 4,
+  "n_layers": 3,
+  "cls_weight": 0.26109289573917854,
+  "batch_size": 16
+}

training_classifiers/binding_affinity/wt_smiles_unpooled/optuna_trials.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11838c42881182dd76b055f06e89b4423659eacd729d695b8a8f4c0a10165da0
+size 40533

training_classifiers/binding_affinity/wt_wt_pooled/.ipynb_checkpoints/optuna_trials-checkpoint.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b685b92714882d618b42b582000574d83c3be2fbecbec5e0de6b5476948b96c5
+size 40700

training_classifiers/binding_affinity/wt_wt_pooled/best_model.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:636de30f4388efd55e8e625c4f2c71a7629982197ef1e53fecf2a4f640df1ae0
+size 182756085

training_classifiers/binding_affinity/wt_wt_pooled/best_params.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "lr": 0.0001730381005812531,
+  "weight_decay": 8.736709570411299e-05,
+  "dropout": 0.17533811687195416,
+  "hidden_dim": 768,
+  "n_heads": 4,
+  "n_layers": 3,
+  "cls_weight": 0.1278591739909013,
+  "batch_size": 16
+}

training_classifiers/binding_affinity/wt_wt_pooled/optuna_trials.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b685b92714882d618b42b582000574d83c3be2fbecbec5e0de6b5476948b96c5
+size 40700

training_classifiers/binding_affinity/wt_wt_unpooled/best_model.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce8da2406036535794909ffde1f1941843096b7d3c71ba772f9170ab123877f2
+size 69333557

training_classifiers/binding_affinity/wt_wt_unpooled/best_params.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "lr": 0.000657577559506255,
+  "weight_decay": 3.3209159985473103e-07,
+  "dropout": 0.16430662769055482,
+  "hidden_dim": 768,
+  "n_heads": 8,
+  "n_layers": 1,
+  "cls_weight": 0.7037398702018655,
+  "batch_size": 16
+}

training_classifiers/binding_affinity/wt_wt_unpooled/optuna_trials.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b330d37733e684ff780b143f17fd26c498f615dfbab8c5b3df08eae7eb019139
+size 40587
training_classifiers/binding_affinity_iptm.py
ADDED
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""
extract_iptm_affinity_csv_all.py

Writes:
  - out_dir/wt_iptm_affinity_all.csv
  - out_dir/smiles_iptm_affinity_all.csv

Also prints:
  - N
  - Spearman rho (affinity vs iptm)
  - Pearson r (affinity vs iptm)
"""

from pathlib import Path
import numpy as np
import pandas as pd


def corr_stats(df: pd.DataFrame, x: str, y: str):
    # pandas handles NaNs if we already dropped them; still be safe
    xx = pd.to_numeric(df[x], errors="coerce")
    yy = pd.to_numeric(df[y], errors="coerce")
    m = xx.notna() & yy.notna()
    xx = xx[m]
    yy = yy[m]
    n = int(m.sum())

    # Pearson r
    pearson_r = float(xx.corr(yy, method="pearson")) if n > 1 else float("nan")
    # Spearman rho
    spearman_rho = float(xx.corr(yy, method="spearman")) if n > 1 else float("nan")

    return {"n": n, "pearson_r": pearson_r, "spearman_rho": spearman_rho}


def clean_one(
    in_csv: Path,
    out_csv: Path,
    iptm_col: str,
    affinity_col: str = "affinity",
    keep_cols=(),
):
    df = pd.read_csv(in_csv)

    # affinity + iptm must exist
    need = [affinity_col, iptm_col]
    missing = [c for c in need if c not in df.columns]
    if missing:
        raise ValueError(f"{in_csv} missing columns: {missing}. Found: {list(df.columns)}")

    # coerce numeric
    df[affinity_col] = pd.to_numeric(df[affinity_col], errors="coerce")
    df[iptm_col] = pd.to_numeric(df[iptm_col], errors="coerce")

    # drop NaNs in either
    df = df.dropna(subset=[affinity_col, iptm_col]).reset_index(drop=True)

    # output cols (standardize names)
    out = pd.DataFrame({
        "affinity": df[affinity_col].astype(float),
        "iptm": df[iptm_col].astype(float),
    })

    # keep split if present (handy for coloring later, but not used for corr)
    if "split" in df.columns:
        out.insert(0, "split", df["split"].astype(str))

    # optional extras for labeling/debug
    for c in keep_cols:
        if c in df.columns:
            out[c] = df[c]

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    out.to_csv(out_csv, index=False)

    stats = corr_stats(out, "iptm", "affinity")
    print(f"[write] {out_csv}")
    print(f"  N={stats['n']} | Pearson r={stats['pearson_r']:.4f} | Spearman rho={stats['spearman_rho']:.4f}")

    # also save stats json next to csv
    stats_path = out_csv.with_suffix(".stats.json")
    with open(stats_path, "w") as f:
        import json
        json.dump(
            {
                "input_csv": str(in_csv),
                "output_csv": str(out_csv),
                "iptm_col": iptm_col,
                "affinity_col": affinity_col,
                **stats,
            },
            f,
            indent=2,
        )


def main():
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--wt_meta_csv", type=str, required=True)
    ap.add_argument("--smiles_meta_csv", type=str, required=True)
    ap.add_argument("--out_dir", type=str, required=True)

    ap.add_argument("--wt_iptm_col", type=str, default="wt_iptm_score")
    ap.add_argument("--smiles_iptm_col", type=str, default="smiles_iptm_score")
    ap.add_argument("--affinity_col", type=str, default="affinity")
    args = ap.parse_args()

    out_dir = Path(args.out_dir)

    clean_one(
        Path(args.wt_meta_csv),
        out_dir / "wt_iptm_affinity_all.csv",
        iptm_col=args.wt_iptm_col,
        affinity_col=args.affinity_col,
        keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES"),
    )

    clean_one(
        Path(args.smiles_meta_csv),
        out_dir / "smiles_iptm_affinity_all.csv",
        iptm_col=args.smiles_iptm_col,
        affinity_col=args.affinity_col,
        keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES", "smiles_sequence"),
    )

    print(f"\n[DONE] CSVs + stats JSONs in: {out_dir}")


if __name__ == "__main__":
    main()
training_classifiers/binding_affinity_split.py
ADDED
|
@@ -0,0 +1,847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import sys
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
# tqdm is optional; we’ll disable it by default in notebooks
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
|
| 16 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 17 |
+
|
| 18 |
+
from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
|
| 19 |
+
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
|
| 20 |
+
|
| 21 |
+
# -------------------------
|
| 22 |
+
# Config
|
| 23 |
+
# -------------------------
|
| 24 |
+
CSV_PATH = Path("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/c-binding_with_openfold_scores.csv")
|
| 25 |
+
|
| 26 |
+
OUT_ROOT = Path(
|
| 27 |
+
"/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_data_cleaned/binding_affinity"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# WT (seq) embedding model
|
| 31 |
+
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
|
| 32 |
+
WT_MAX_LEN = 1022
|
| 33 |
+
WT_BATCH = 32
|
| 34 |
+
|
| 35 |
+
# SMILES embedding model + tokenizer
|
| 36 |
+
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
|
| 37 |
+
TOKENIZER_VOCAB = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_vocab.txt"
|
| 38 |
+
TOKENIZER_SPLITS = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_splits.txt"
|
| 39 |
+
SMI_MAX_LEN = 768
|
| 40 |
+
SMI_BATCH = 128
|
| 41 |
+
|
| 42 |
+
# Split config
|
| 43 |
+
TRAIN_FRAC = 0.80
|
| 44 |
+
RANDOM_SEED = 1986
|
| 45 |
+
AFFINITY_Q_BINS = 30
|
| 46 |
+
|
| 47 |
+
# Columns expected in CSV
|
| 48 |
+
COL_SEQ1 = "seq1"
|
| 49 |
+
COL_SEQ2 = "seq2"
|
| 50 |
+
COL_AFF = "affinity"
|
| 51 |
+
COL_F2S = "Fasta2SMILES"
|
| 52 |
+
COL_REACT = "REACT_SMILES"
|
| 53 |
+
COL_WT_IPTM = "wt_iptm_score"
|
| 54 |
+
COL_SMI_IPTM = "smiles_iptm_score"
|
| 55 |
+
|
| 56 |
+
# Device
|
| 57 |
+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
|
| 59 |
+
# -------------------------
|
| 60 |
+
# Quiet / notebook-safe output controls
|
| 61 |
+
# -------------------------
|
| 62 |
+
QUIET = True # suppress most prints
|
| 63 |
+
USE_TQDM = False # disable tqdm bars (recommended in Jupyter to avoid crashing)
|
| 64 |
+
LOG_FILE = None # optionally: OUT_ROOT / "build.log"
|
| 65 |
+
|
| 66 |
+
def log(msg: str):
|
| 67 |
+
if LOG_FILE is not None:
|
| 68 |
+
Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
|
| 69 |
+
with open(LOG_FILE, "a") as f:
|
| 70 |
+
f.write(msg.rstrip() + "\n")
|
| 71 |
+
if not QUIET:
|
| 72 |
+
print(msg)
|
| 73 |
+
|
| 74 |
+
def pbar(it, **kwargs):
|
| 75 |
+
return tqdm(it, **kwargs) if USE_TQDM else it
|
| 76 |
+
|
| 77 |
+
@contextmanager
|
| 78 |
+
def section(title: str):
|
| 79 |
+
log(f"\n=== {title} ===")
|
| 80 |
+
yield
|
| 81 |
+
log(f"=== done: {title} ===")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# -------------------------
|
| 85 |
+
# Helpers
|
| 86 |
+
# -------------------------
|
| 87 |
+
def has_uaa(seq: str) -> bool:
|
| 88 |
+
return "X" in str(seq).upper()
|
| 89 |
+
|
| 90 |
+
def affinity_to_class(a: float) -> str:
|
| 91 |
+
# High: >= 9 ; Moderate: [7, 9) ; Low: < 7
|
| 92 |
+
if a >= 9.0:
|
| 93 |
+
return "High"
|
| 94 |
+
elif a >= 7.0:
|
| 95 |
+
return "Moderate"
|
| 96 |
+
else:
|
| 97 |
+
return "Low"
|
| 98 |
+
|
| 99 |
+
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
|
| 100 |
+
df = df.copy()
|
| 101 |
+
|
| 102 |
+
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
|
| 103 |
+
df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
|
| 104 |
+
|
| 105 |
+
df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
|
| 106 |
+
|
| 107 |
+
try:
|
| 108 |
+
df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
|
| 109 |
+
strat_col = "aff_bin"
|
| 110 |
+
except Exception:
|
| 111 |
+
df["aff_bin"] = df["affinity_class"]
|
| 112 |
+
strat_col = "aff_bin"
|
| 113 |
+
|
| 114 |
+
rng = np.random.RandomState(RANDOM_SEED)
|
| 115 |
+
|
| 116 |
+
df["split"] = None
|
| 117 |
+
for _, g in df.groupby(strat_col, observed=True):
|
| 118 |
+
idx = g.index.to_numpy()
|
| 119 |
+
rng.shuffle(idx)
|
| 120 |
+
n_train = int(math.floor(len(idx) * TRAIN_FRAC))
|
| 121 |
+
df.loc[idx[:n_train], "split"] = "train"
|
| 122 |
+
df.loc[idx[n_train:], "split"] = "val"
|
| 123 |
+
|
| 124 |
+
df["split"] = df["split"].fillna("train")
|
| 125 |
+
return df
|
| 126 |
+
|
| 127 |
+
def _summ(x):
|
| 128 |
+
x = np.asarray(x, dtype=float)
|
| 129 |
+
x = x[~np.isnan(x)]
|
| 130 |
+
if len(x) == 0:
|
| 131 |
+
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
|
| 132 |
+
return {
|
| 133 |
+
"n": int(len(x)),
|
| 134 |
+
"mean": float(np.mean(x)),
|
| 135 |
+
"std": float(np.std(x)),
|
| 136 |
+
"p50": float(np.quantile(x, 0.50)),
|
| 137 |
+
"p95": float(np.quantile(x, 0.95)),
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
def _len_stats(seqs):
|
| 141 |
+
lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
|
| 142 |
+
if len(lens) == 0:
|
| 143 |
+
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
|
| 144 |
+
return {
|
| 145 |
+
"n": int(len(lens)),
|
| 146 |
+
"mean": float(lens.mean()),
|
| 147 |
+
"std": float(lens.std()),
|
| 148 |
+
"p50": float(np.quantile(lens, 0.50)),
|
| 149 |
+
"p95": float(np.quantile(lens, 0.95)),
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
def verify_split_before_embedding(
|
| 153 |
+
df2: pd.DataFrame,
|
| 154 |
+
affinity_col: str,
|
| 155 |
+
split_col: str,
|
| 156 |
+
seq_col: str,
|
| 157 |
+
iptm_col: str,
|
| 158 |
+
aff_class_col: str = "affinity_class",
|
| 159 |
+
aff_bins: int = 30,
|
| 160 |
+
save_report_prefix: str | None = None,
|
| 161 |
+
verbose: bool = False,
|
| 162 |
+
):
|
| 163 |
+
"""
|
| 164 |
+
Notebook-safe: by default prints only ONE line via `log()`.
|
| 165 |
+
Optionally writes CSV reports (stats + class proportions).
|
| 166 |
+
"""
|
| 167 |
+
df2 = df2.copy()
|
| 168 |
+
df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
|
| 169 |
+
df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
|
| 170 |
+
|
| 171 |
+
assert split_col in df2.columns, f"Missing split col: {split_col}"
|
| 172 |
+
assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
|
| 173 |
+
assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
|
| 177 |
+
except Exception:
|
| 178 |
+
df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
|
| 179 |
+
|
| 180 |
+
tr = df2[df2[split_col] == "train"].reset_index(drop=True)
|
| 181 |
+
va = df2[df2[split_col] == "val"].reset_index(drop=True)
|
| 182 |
+
|
| 183 |
+
tr_aff = _summ(tr[affinity_col].to_numpy())
|
| 184 |
+
va_aff = _summ(va[affinity_col].to_numpy())
|
| 185 |
+
tr_len = _len_stats(tr[seq_col].tolist())
|
| 186 |
+
va_len = _len_stats(va[seq_col].tolist())
|
| 187 |
+
|
| 188 |
+
# bin drift
|
| 189 |
+
bin_ct = (
|
| 190 |
+
df2.groupby([split_col, "_aff_bin_dbg"])
|
| 191 |
+
.size()
|
| 192 |
+
.groupby(level=0)
|
| 193 |
+
.apply(lambda s: s / s.sum())
|
| 194 |
+
)
|
| 195 |
+
tr_bins = bin_ct.loc["train"]
|
| 196 |
+
va_bins = bin_ct.loc["val"]
|
| 197 |
+
all_bins = tr_bins.index.union(va_bins.index)
|
| 198 |
+
tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
|
| 199 |
+
va_bins = va_bins.reindex(all_bins, fill_value=0.0)
|
| 200 |
+
max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
|
| 201 |
+
|
| 202 |
+
msg = (
|
| 203 |
+
f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
|
| 204 |
+
f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
|
| 205 |
+
f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
|
| 206 |
+
f"max_bin_diff={max_bin_diff:.4f}"
|
| 207 |
+
)
|
| 208 |
+
log(msg)
|
| 209 |
+
|
| 210 |
+
if verbose and (not QUIET):
|
| 211 |
+
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
|
| 212 |
+
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
|
| 213 |
+
print("\n[verbose] affinity_class counts:\n", class_ct)
|
| 214 |
+
print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
|
| 215 |
+
|
| 216 |
+
if save_report_prefix is not None:
|
| 217 |
+
out = Path(save_report_prefix)
|
| 218 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 219 |
+
|
| 220 |
+
stats_df = pd.DataFrame([
|
| 221 |
+
{"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
|
| 222 |
+
{"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
|
| 223 |
+
])
|
| 224 |
+
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
|
| 225 |
+
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
|
| 226 |
+
|
| 227 |
+
stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
|
| 228 |
+
class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# -------------------------
|
| 232 |
+
# WT pooled (ESM2)
|
| 233 |
+
# -------------------------
|
| 234 |
+
@torch.no_grad()
|
| 235 |
+
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
|
| 236 |
+
embs = []
|
| 237 |
+
for i in pbar(range(0, len(seqs), batch_size)):
|
| 238 |
+
batch = seqs[i:i + batch_size]
|
| 239 |
+
inputs = tokenizer(
|
| 240 |
+
batch,
|
| 241 |
+
padding=True,
|
| 242 |
+
truncation=True,
|
| 243 |
+
max_length=max_length,
|
| 244 |
+
return_tensors="pt",
|
| 245 |
+
)
|
| 246 |
+
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 247 |
+
out = model(**inputs)
|
| 248 |
+
h = out.last_hidden_state # (B, L, H)
|
| 249 |
+
|
| 250 |
+
attn = inputs["attention_mask"].unsqueeze(-1) # (B, L, 1)
|
| 251 |
+
summed = (h * attn).sum(dim=1) # (B, H)
|
| 252 |
+
denom = attn.sum(dim=1).clamp(min=1e-9) # (B, 1)
|
| 253 |
+
pooled = (summed / denom).detach().cpu().numpy()
|
| 254 |
+
embs.append(pooled)
|
| 255 |
+
|
| 256 |
+
return np.vstack(embs)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# -------------------------
|
| 260 |
+
# WT unpooled (ESM2)
|
| 261 |
+
# -------------------------
|
| 262 |
+
@torch.no_grad()
|
| 263 |
+
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
|
| 264 |
+
tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
|
| 265 |
+
tok = {k: v.to(DEVICE) for k, v in tok.items()}
|
| 266 |
+
out = model(**tok)
|
| 267 |
+
h = out.last_hidden_state[0] # (L, H)
|
| 268 |
+
attn = tok["attention_mask"][0].bool() # (L,)
|
| 269 |
+
ids = tok["input_ids"][0]
|
| 270 |
+
|
| 271 |
+
keep = attn.clone()
|
| 272 |
+
if cls_id is not None:
|
| 273 |
+
keep &= (ids != cls_id)
|
| 274 |
+
if eos_id is not None:
|
| 275 |
+
keep &= (ids != eos_id)
|
| 276 |
+
|
| 277 |
+
return h[keep].detach().cpu().to(torch.float16).numpy()
|
| 278 |
+
|
| 279 |
+
def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
|
| 280 |
+
"""
|
| 281 |
+
Expects df_split to have:
|
| 282 |
+
- target_sequence (seq1)
|
| 283 |
+
- sequence (binder seq2; WT binder)
|
| 284 |
+
- label, affinity_class, COL_AFF, COL_WT_IPTM
|
| 285 |
+
Saves a dataset where each row contains BOTH:
|
| 286 |
+
- target_embedding (Lt,H), target_attention_mask, target_length
|
| 287 |
+
- binder_embedding (Lb,H), binder_attention_mask, binder_length
|
| 288 |
+
"""
|
| 289 |
+
cls_id = tokenizer.cls_token_id
|
| 290 |
+
eos_id = tokenizer.eos_token_id
|
| 291 |
+
H = model.config.hidden_size
|
| 292 |
+
|
| 293 |
+
features = Features({
|
| 294 |
+
"target_sequence": Value("string"),
|
| 295 |
+
"sequence": Value("string"),
|
| 296 |
+
"label": Value("float32"),
|
| 297 |
+
"affinity": Value("float32"),
|
| 298 |
+
"affinity_class": Value("string"),
|
| 299 |
+
|
| 300 |
+
"target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
|
| 301 |
+
"target_attention_mask": HFSequence(Value("int8")),
|
| 302 |
+
"target_length": Value("int64"),
|
| 303 |
+
|
| 304 |
+
"binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
|
| 305 |
+
"binder_attention_mask": HFSequence(Value("int8")),
|
| 306 |
+
"binder_length": Value("int64"),
|
| 307 |
+
|
| 308 |
+
COL_WT_IPTM: Value("float32"),
|
| 309 |
+
COL_AFF: Value("float32"),
|
| 310 |
+
})
|
| 311 |
+
|
| 312 |
+
def gen_rows(df: pd.DataFrame):
|
| 313 |
+
for r in pbar(df.itertuples(index=False), total=len(df)):
|
| 314 |
+
tgt = str(getattr(r, "target_sequence")).strip()
|
| 315 |
+
bnd = str(getattr(r, "sequence")).strip()
|
| 316 |
+
|
| 317 |
+
y = float(getattr(r, "label"))
|
| 318 |
+
aff = float(getattr(r, COL_AFF))
|
| 319 |
+
acls = str(getattr(r, "affinity_class"))
|
| 320 |
+
|
| 321 |
+
iptm = getattr(r, COL_WT_IPTM)
|
| 322 |
+
iptm = float(iptm) if pd.notna(iptm) else np.nan
|
| 323 |
+
|
| 324 |
+
# token embeddings for target + binder (both ESM)
|
| 325 |
+
t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lt,H)
|
| 326 |
+
b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lb,H)
|
| 327 |
+
|
| 328 |
+
t_list = t_emb.tolist()
|
| 329 |
+
b_list = b_emb.tolist()
|
| 330 |
+
Lt = len(t_list)
|
| 331 |
+
Lb = len(b_list)
|
| 332 |
+
|
| 333 |
+
yield {
|
| 334 |
+
"target_sequence": tgt,
|
| 335 |
+
"sequence": bnd,
|
| 336 |
+
"label": np.float32(y),
|
| 337 |
+
"affinity": np.float32(aff),
|
| 338 |
+
"affinity_class": acls,
|
| 339 |
+
|
| 340 |
+
"target_embedding": t_list,
|
| 341 |
+
"target_attention_mask": [1] * Lt,
|
| 342 |
+
"target_length": int(Lt),
|
| 343 |
+
|
| 344 |
+
"binder_embedding": b_list,
|
| 345 |
+
"binder_attention_mask": [1] * Lb,
|
| 346 |
+
"binder_length": int(Lb),
|
| 347 |
+
|
| 348 |
+
COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
|
| 349 |
+
COL_AFF: np.float32(aff),
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 353 |
+
ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
|
| 354 |
+
ds.save_to_disk(str(out_dir), max_shard_size="1GB")
|
| 355 |
+
return ds
|
| 356 |
+
|
| 357 |
+
def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
|
| 358 |
+
smi_tok, smi_roformer):
|
| 359 |
+
"""
|
| 360 |
+
df_split must have:
|
| 361 |
+
- target_sequence (seq1)
|
| 362 |
+
- sequence (binder smiles string)
|
| 363 |
+
- label, affinity_class, COL_AFF, COL_SMI_IPTM
|
| 364 |
+
Saves rows with:
|
| 365 |
+
target_embedding (Lt,Ht) from ESM
|
| 366 |
+
binder_embedding (Lb,Hb) from PeptideCLM
|
| 367 |
+
"""
|
| 368 |
+
cls_id = wt_tokenizer.cls_token_id
|
| 369 |
+
eos_id = wt_tokenizer.eos_token_id
|
| 370 |
+
Ht = wt_model_unpooled.config.hidden_size
|
| 371 |
+
|
| 372 |
+
# Infer Hb from one forward pass? easiest: run one mini batch outside in main if you want.
|
| 373 |
+
# Here: we’ll infer from model config if available.
|
| 374 |
+
Hb = getattr(smi_roformer.config, "hidden_size", None)
|
| 375 |
+
if Hb is None:
|
| 376 |
+
Hb = getattr(smi_roformer.config, "dim", None)
|
| 377 |
+
if Hb is None:
|
| 378 |
+
raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
|
| 379 |
+
|
| 380 |
+
features = Features({
|
| 381 |
+
"target_sequence": Value("string"),
|
| 382 |
+
"sequence": Value("string"),
|
| 383 |
+
"label": Value("float32"),
|
| 384 |
+
"affinity": Value("float32"),
|
| 385 |
+
"affinity_class": Value("string"),
|
| 386 |
+
|
| 387 |
+
"target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
|
| 388 |
+
"target_attention_mask": HFSequence(Value("int8")),
|
| 389 |
+
"target_length": Value("int64"),
|
| 390 |
+
|
| 391 |
+
"binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
|
| 392 |
+
"binder_attention_mask": HFSequence(Value("int8")),
|
| 393 |
+
"binder_length": Value("int64"),
|
| 394 |
+
|
| 395 |
+
COL_SMI_IPTM: Value("float32"),
|
| 396 |
+
COL_AFF: Value("float32"),
|
| 397 |
+
})
|
| 398 |
+
|
| 399 |
+
def gen_rows(df: pd.DataFrame):
|
| 400 |
+
for r in pbar(df.itertuples(index=False), total=len(df)):
|
| 401 |
+
tgt = str(getattr(r, "target_sequence")).strip()
|
| 402 |
+
bnd = str(getattr(r, "sequence")).strip()
|
| 403 |
+
|
| 404 |
+
y = float(getattr(r, "label"))
|
| 405 |
+
aff = float(getattr(r, COL_AFF))
|
| 406 |
+
acls = str(getattr(r, "affinity_class"))
|
| 407 |
+
|
| 408 |
+
iptm = getattr(r, COL_SMI_IPTM)
|
| 409 |
+
iptm = float(iptm) if pd.notna(iptm) else np.nan
|
| 410 |
+
|
| 411 |
+
# target token embeddings (ESM)
|
| 412 |
+
t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
|
| 413 |
+
t_list = t_emb.tolist()
|
| 414 |
+
Lt = len(t_list)
|
| 415 |
+
|
| 416 |
+
# binder token embeddings (PeptideCLM) — single-item batch
|
| 417 |
+
_, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
|
| 418 |
+
[bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
|
| 419 |
+
)
|
| 420 |
+
b_emb = tok_list[0] # np.float16 (Lb, Hb)
|
| 421 |
+
b_list = b_emb.tolist()
|
| 422 |
+
Lb = int(lengths[0])
|
| 423 |
+
b_mask = mask_list[0].astype(np.int8).tolist()
|
| 424 |
+
|
| 425 |
+
yield {
|
| 426 |
+
"target_sequence": tgt,
|
| 427 |
+
"sequence": bnd,
|
| 428 |
+
"label": np.float32(y),
|
| 429 |
+
"affinity": np.float32(aff),
|
| 430 |
+
"affinity_class": acls,
|
| 431 |
+
|
| 432 |
+
"target_embedding": t_list,
|
| 433 |
+
"target_attention_mask": [1] * Lt,
|
| 434 |
+
"target_length": int(Lt),
|
| 435 |
+
|
| 436 |
+
"binder_embedding": b_list,
|
| 437 |
+
"binder_attention_mask": [int(x) for x in b_mask],
|
| 438 |
+
"binder_length": int(Lb),
|
| 439 |
+
|
| 440 |
+
COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
|
| 441 |
+
COL_AFF: np.float32(aff),
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 445 |
+
ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
|
| 446 |
+
ds.save_to_disk(str(out_dir), max_shard_size="1GB")
|
| 447 |
+
return ds
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
# -------------------------
|
| 451 |
+
# SMILES pooled + unpooled (PeptideCLM)
|
| 452 |
+
# -------------------------
|
| 453 |
+
def get_special_ids(tokenizer_obj):
|
| 454 |
+
cand = [
|
| 455 |
+
getattr(tokenizer_obj, "pad_token_id", None),
|
| 456 |
+
getattr(tokenizer_obj, "cls_token_id", None),
|
| 457 |
+
getattr(tokenizer_obj, "sep_token_id", None),
|
| 458 |
+
getattr(tokenizer_obj, "bos_token_id", None),
|
| 459 |
+
getattr(tokenizer_obj, "eos_token_id", None),
|
| 460 |
+
getattr(tokenizer_obj, "mask_token_id", None),
|
| 461 |
+
]
|
| 462 |
+
return sorted({x for x in cand if x is not None})
|
| 463 |
+
|
| 464 |
+
@torch.no_grad()
|
| 465 |
+
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
|
| 466 |
+
tok = tokenizer_obj(
|
| 467 |
+
batch_sequences,
|
| 468 |
+
return_tensors="pt",
|
| 469 |
+
padding=True,
|
| 470 |
+
truncation=True,
|
| 471 |
+
max_length=max_length,
|
| 472 |
+
)
|
| 473 |
+
input_ids = tok["input_ids"].to(DEVICE)
|
| 474 |
+
attention_mask = tok["attention_mask"].to(DEVICE)
|
| 475 |
+
|
| 476 |
+
outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
|
| 477 |
+
last_hidden = outputs.last_hidden_state # (B, L, H)
|
| 478 |
+
|
| 479 |
+
special_ids = get_special_ids(tokenizer_obj)
|
| 480 |
+
valid = attention_mask.bool()
|
| 481 |
+
if len(special_ids) > 0:
|
| 482 |
+
sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
|
| 483 |
+
if hasattr(torch, "isin"):
|
| 484 |
+
valid = valid & (~torch.isin(input_ids, sid))
|
| 485 |
+
else:
|
| 486 |
+
m = torch.zeros_like(valid)
|
| 487 |
+
for s in special_ids:
|
| 488 |
+
m |= (input_ids == s)
|
| 489 |
+
valid = valid & (~m)
|
| 490 |
+
|
| 491 |
+
valid_f = valid.unsqueeze(-1).float()
|
| 492 |
+
summed = torch.sum(last_hidden * valid_f, dim=1)
|
| 493 |
+
denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
|
| 494 |
+
pooled = (summed / denom).detach().cpu().numpy()
|
| 495 |
+
|
| 496 |
+
token_emb_list, mask_list, lengths = [], [], []
|
| 497 |
+
for b in range(last_hidden.shape[0]):
|
| 498 |
+
emb = last_hidden[b, valid[b]] # (Li, H)
|
| 499 |
+
token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
|
| 500 |
+
li = emb.shape[0]
|
| 501 |
+
lengths.append(int(li))
|
| 502 |
+
mask_list.append(np.ones((li,), dtype=np.int8))
|
| 503 |
+
|
| 504 |
+
return pooled, token_emb_list, mask_list, lengths
|
| 505 |
+
|
| 506 |
+
def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
|
| 507 |
+
pooled_all = []
|
| 508 |
+
token_emb_all = []
|
| 509 |
+
mask_all = []
|
| 510 |
+
lengths_all = []
|
| 511 |
+
|
| 512 |
+
for i in pbar(range(0, len(seqs), batch_size)):
|
| 513 |
+
batch = seqs[i:i + batch_size]
|
| 514 |
+
pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
|
| 515 |
+
batch, tokenizer_obj, model_roformer, max_length
|
| 516 |
+
)
|
| 517 |
+
pooled_all.append(pooled)
|
| 518 |
+
token_emb_all.extend(tok_list)
|
| 519 |
+
mask_all.extend(m_list)
|
| 520 |
+
lengths_all.extend(lens)
|
| 521 |
+
|
| 522 |
+
return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
|
| 523 |
+
|
| 524 |
+
# -------------------------
|
| 525 |
+
# Target embedding cache (NO extra ESM runs)
|
| 526 |
+
# We will compute target pooled embeddings ONCE from WT view, then reuse for SMILES.
|
| 527 |
+
# -------------------------
|
| 528 |
+
def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
|
| 529 |
+
wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
|
| 530 |
+
wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
|
| 531 |
+
|
| 532 |
+
# compute target pooled embeddings once
|
| 533 |
+
tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
|
| 534 |
+
tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
|
| 535 |
+
|
| 536 |
+
wt_train_tgt_emb = wt_pooled_embeddings(
|
| 537 |
+
tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
|
| 538 |
+
)
|
| 539 |
+
wt_val_tgt_emb = wt_pooled_embeddings(
|
| 540 |
+
tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
# build dict: target_sequence -> embedding (float32 array)
|
| 544 |
+
# if duplicates exist, last wins; you can add checks if needed
|
| 545 |
+
train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
|
| 546 |
+
val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
|
| 547 |
+
return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
|
| 548 |
+
# -------------------------
# Main
# -------------------------
def main():
    log(f"[INFO] DEVICE: {DEVICE}")
    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    # 1) Load
    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)
        for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
            if c in df.columns:
                df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)

        # Dedup on the full identity tuple
        DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
        df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)

        print("Rows after dedup on", DEDUP_COLS, ":", len(df))

        need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
        missing = [c for c in need if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # numeric affinity for both branches
        df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")

    # 2) Build WT subset + SMILES subset separately (NO global dropping)
    with section("prepare wt/smiles subsets"):
        # WT: requires a canonical peptide sequence (no X) + affinity
        df_wt = df.copy()
        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
        df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)

        # SMILES: requires affinity + a usable picked SMILES (UAA -> REACT, else -> Fasta2SMILES)
        df_smi = df.copy()
        df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ].reset_index(drop=True)  # an empty ipTM means something went wrong with that row's SMILES

        is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
        df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
        df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
        df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
        df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)

    log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")

    # 3) Split separately (different sizes and memberships are expected)
    with section("split wt and smiles separately"):
        df_wt2 = make_distribution_matched_split(df_wt)
        df_smi2 = make_distribution_matched_split(df_smi)

        # save split tables
        wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
        df_wt2.to_csv(wt_split_csv, index=False)
        df_smi2.to_csv(smi_split_csv, index=False)
        log(f"Saved WT split meta: {wt_split_csv}")
        log(f"Saved SMILES split meta: {smi_split_csv}")

        # lightweight double-check (one call per branch)
        verify_split_before_embedding(
            df2=df_wt2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="wt_sequence",
            iptm_col=COL_WT_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
            verbose=False,
        )
        verify_split_before_embedding(
            df2=df_smi2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="smiles_sequence",
            iptm_col=COL_SMI_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
            verbose=False,
        )

    # Prepare split views
    def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
        out = df_in.copy()
        out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()  # <-- NEW
        out["sequence"] = out[binder_seq_col].astype(str).str.strip()   # binder
        out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
        out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
        return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]

    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)

    # -------------------------
    # Split views
    # -------------------------
    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)

    # =========================
    # TARGET pooled embeddings (ESM) — SEPARATE per branch
    # =========================
    with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
        wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
        wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

        # ---- WT targets ----
        wt_train_tgt_emb = wt_pooled_embeddings(
            wt_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_tgt_emb = wt_pooled_embeddings(
            wt_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        # ---- SMILES targets (independent; may include UAA-only targets) ----
        smi_train_tgt_emb = wt_pooled_embeddings(
            smi_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_val_tgt_emb = wt_pooled_embeddings(
            smi_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

    # =========================
    # WT pooled binder embeddings (binder = WT peptide)
    # =========================
    with section("WT pooled binder embeddings + save"):
        wt_train_emb = wt_pooled_embeddings(
            wt_train["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_emb = wt_pooled_embeddings(
            wt_val["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_train_ds = Dataset.from_dict({
            "target_sequence": wt_train["target_sequence"].tolist(),
            "sequence": wt_train["sequence"].tolist(),
            "label": wt_train["label"].astype(float).tolist(),
            "target_embedding": wt_train_tgt_emb,
            "embedding": wt_train_emb,
            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_train["affinity_class"].tolist(),
        })

        wt_val_ds = Dataset.from_dict({
            "target_sequence": wt_val["target_sequence"].tolist(),
            "sequence": wt_val["sequence"].tolist(),
            "label": wt_val["label"].astype(float).tolist(),
            "target_embedding": wt_val_tgt_emb,
            "embedding": wt_val_emb,
            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_val["affinity_class"].tolist(),
        })

        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
        log(f"Saved WT pooled -> {wt_pooled_out}")

    # =========================
    # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
    # =========================
    with section("SMILES pooled binder embeddings + save"):
        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        smi_roformer = (
            AutoModelForMaskedLM
            .from_pretrained(SMI_MODEL_NAME)
            .roformer
            .to(DEVICE)
            .eval()
        )

        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_train["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_val["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_train_ds = Dataset.from_dict({
            "target_sequence": smi_train["target_sequence"].tolist(),
            "sequence": smi_train["sequence"].tolist(),
            "label": smi_train["label"].astype(float).tolist(),
            "target_embedding": smi_train_tgt_emb,
            "embedding": smi_train_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_train["affinity_class"].tolist(),
        })

        smi_val_ds = Dataset.from_dict({
            "target_sequence": smi_val["target_sequence"].tolist(),
            "sequence": smi_val["sequence"].tolist(),
            "label": smi_val["label"].astype(float).tolist(),
            "target_embedding": smi_val_tgt_emb,
            "embedding": smi_val_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_val["affinity_class"].tolist(),
        })

        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
        log(f"Saved SMILES pooled -> {smi_pooled_out}")

    # =========================
    # WT unpooled paired (ESM target + ESM binder) + save
    # =========================
    with section("WT unpooled paired embeddings + save"):
        wt_tok_unpooled = wt_tok  # reuse tokenizer
        wt_esm_unpooled = wt_esm  # reuse model

        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
        wt_unpooled_dd = DatasetDict({
            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
                                               wt_tok_unpooled, wt_esm_unpooled),
            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
                                             wt_tok_unpooled, wt_esm_unpooled),
        })
        # (Optional) also save at the DatasetDict root for a single load_from_disk path:
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
        log(f"Saved WT unpooled -> {wt_unpooled_out}")

    # =========================
    # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
    # =========================
    with section("SMILES unpooled paired embeddings + save"):
        # reuse the already-loaded smi_tok/smi_roformer from the pooled section if still in scope;
        # otherwise re-init here:
        # smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        # smi_roformer = AutoModelForMaskedLM.from_pretrained(SMI_MODEL_NAME).roformer.to(DEVICE).eval()

        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
        smi_unpooled_dd = DatasetDict({
            "train": build_smiles_unpooled_paired_dataset(
                smi_train, smi_unpooled_out / "train",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
            "val": build_smiles_unpooled_paired_dataset(
                smi_val, smi_unpooled_out / "val",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
        })
        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")

    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")


if __name__ == "__main__":
    main()
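One detail of main() worth isolating: rows whose peptide sequence contains an "X" (the unnatural-amino-acid placeholder) take their SMILES from the reactive column, and everything else falls back to the Fasta2SMILES column. A toy illustration of that np.where routing (the column names and values below are made up):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "seq": ["ACDK", "ACXK"],            # "X" marks a UAA position
        "react_smiles": ["r1", "r2"],
        "fasta2smiles": ["f1", "f2"],
    })
    is_uaa = df["seq"].str.contains("X", case=False, na=False)
    df["smiles_sequence"] = np.where(is_uaa, df["react_smiles"], df["fasta2smiles"])
    print(df["smiles_sequence"].tolist())   # ['f1', 'r2']
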
training_classifiers/binding_training.py
ADDED
@@ -0,0 +1,414 @@
import os, json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import optuna
from datasets import load_from_disk, DatasetDict
from scipy.stats import spearmanr
from lightning.pytorch import seed_everything
seed_everything(1986)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)


# -----------------------------
# Affinity class thresholds (final spec)
# High >= 9 ; Moderate 7-9 ; Low < 7
# 0=High, 1=Moderate, 2=Low
# -----------------------------
def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
    high = y >= 9.0
    low = y < 7.0
    mid = ~(high | low)
    cls = torch.zeros_like(y, dtype=torch.long)
    cls[mid] = 1
    cls[low] = 2
    return cls

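To make the three-way binning concrete, a quick check of the thresholds on a few illustrative affinity values:

    import torch

    y = torch.tensor([9.5, 9.0, 8.2, 7.0, 6.9, 4.0])
    print(affinity_to_class_tensor(y))
    # tensor([0, 0, 1, 1, 2, 2]) -> 0=High (>=9), 1=Moderate (7-9), 2=Low (<7)
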
# -----------------------------
# Load paired DatasetDict
# -----------------------------
def load_split_paired(path: str):
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    if "train" not in dd or "val" not in dd:
        raise ValueError(f"DatasetDict missing train/val at {path}")
    return dd["train"], dd["val"]


# -----------------------------
# Collate: pooled paired
# -----------------------------
def collate_pair_pooled(batch):
    Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)  # (B,Ht)
    Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)  # (B,Hb)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    return Pt, Pb, y


# -----------------------------
# Collate: unpooled paired
# -----------------------------
def collate_pair_unpooled(batch):
    B = len(batch)
    Ht = len(batch[0]["target_embedding"][0])
    Hb = len(batch[0]["binder_embedding"][0])
    Lt_max = max(int(x["target_length"]) for x in batch)
    Lb_max = max(int(x["binder_length"]) for x in batch)

    Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
    Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
    Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
    Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        t = torch.tensor(x["target_embedding"], dtype=torch.float32)
        b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
        lt, lb = t.shape[0], b.shape[0]
        Pt[i, :lt] = t
        Pb[i, :lb] = b
        Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
        Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)

    return Pt, Mt, Pb, Mb, y

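A sanity check of the unpooled collate on two hypothetical rows of different lengths (the field names follow the paired datasets saved by binding_affinity_split.py; the numbers are toy values):

    batch = [
        {"target_embedding": [[0.1] * 8] * 3, "binder_embedding": [[0.2] * 4] * 2,
         "target_length": 3, "binder_length": 2,
         "target_attention_mask": [1, 1, 1], "binder_attention_mask": [1, 1],
         "label": 7.5},
        {"target_embedding": [[0.3] * 8] * 5, "binder_embedding": [[0.4] * 4] * 6,
         "target_length": 5, "binder_length": 6,
         "target_attention_mask": [1] * 5, "binder_attention_mask": [1] * 6,
         "label": 9.1},
    ]
    Pt, Mt, Pb, Mb, y = collate_pair_unpooled(batch)
    print(Pt.shape, Pb.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 6, 4])
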
# -----------------------------
# Cross-attention models
# -----------------------------
class CrossAttnPooled(nn.Module):
    """
    Pooled vectors are treated as single-token sequences for cross-attention.
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def forward(self, t_vec, b_vec):
        # (B,Ht),(B,Hb)
        t = self.t_proj(t_vec).unsqueeze(0)  # (1,B,H)
        b = self.b_proj(b_vec).unsqueeze(0)  # (1,B,H)

        for L in self.layers:
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"]((t + t_attn).transpose(0, 1)).transpose(0, 1)
            t = L["n2t"]((t + L["fft"](t)).transpose(0, 1)).transpose(0, 1)

            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"]((b + b_attn).transpose(0, 1)).transpose(0, 1)
            b = L["n2b"]((b + L["ffb"](b)).transpose(0, 1)).transpose(0, 1)

        t0 = t[0]
        b0 = b[0]
        z = torch.cat([t0, b0], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)

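A smoke test of the pooled model on random vectors (the dimensions here are placeholders; in practice Ht/Hb are read off the stored embeddings):

    model = CrossAttnPooled(Ht=1280, Hb=768, hidden=256, n_heads=4, n_layers=1)
    t_vec = torch.randn(8, 1280)     # pooled target embeddings
    b_vec = torch.randn(8, 768)      # pooled binder embeddings
    pred, logits = model(t_vec, b_vec)
    print(pred.shape, logits.shape)  # torch.Size([8]) torch.Size([8, 3])
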
class CrossAttnUnpooled(nn.Module):
    """
    Token sequences with masks; alternating cross-attention.
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

    def forward(self, T, Mt, B, Mb):
        # T:(B,Lt,Ht), Mt:(B,Lt) ; B:(B,Lb,Hb), Mb:(B,Lb)
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        kp_t = ~Mt  # key_padding_mask: True = pad
        kp_b = ~Mb

        for L in self.layers:
            # T attends to B
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))

            # B attends to T
            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))

        t_pool = self.masked_mean(T, Mt)
        b_pool = self.masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)

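And the unpooled variant on random token sequences with all-ones masks (again, toy shapes only):

    model = CrossAttnUnpooled(Ht=1280, Hb=768, hidden=256, n_heads=4, n_layers=2)
    T = torch.randn(2, 50, 1280)           # target token embeddings
    B = torch.randn(2, 30, 768)            # binder token embeddings
    Mt = torch.ones(2, 50, dtype=torch.bool)
    Mb = torch.ones(2, 30, dtype=torch.bool)
    pred, logits = model(T, Mt, B, Mb)
    print(pred.shape, logits.shape)        # torch.Size([2]) torch.Size([2, 3])
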
# -----------------------------
# Train/eval
# -----------------------------
@torch.no_grad()
def eval_spearman_pooled(model, loader):
    model.eval()
    ys, ps = [], []
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        pred, _ = model(t, b)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))

@torch.no_grad()
def eval_spearman_unpooled(model, loader):
    model.eval()
    ys, ps = [], []
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        pred, _ = model(T, Mt, B, Mb)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))

def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    model.train()
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        y_cls = affinity_to_class_tensor(y)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(t, b)
        L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
        L.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()

def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    model.train()
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        y_cls = affinity_to_class_tensor(y)

        opt.zero_grad(set_to_none=True)
        pred, logits = model(T, Mt, B, Mb)
        L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
        L.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()

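Both training loops optimize the same joint objective, MSE on the affinity plus a weighted cross-entropy on the derived class; spelled out for one hypothetical batch (reusing this file's imports):

    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()
    pred = torch.tensor([8.4, 6.1])      # regression head output (made-up)
    logits = torch.randn(2, 3)           # classification head output
    y = torch.tensor([9.2, 5.8])         # true affinities
    L = loss_reg(pred, y) + 0.5 * loss_cls(logits, affinity_to_class_tensor(y))
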
# -----------------------------
# Optuna objective
# -----------------------------
def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.4)
    hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
    n_heads = trial.suggest_categorical("n_heads", [4, 8])
    n_layers = trial.suggest_int("n_layers", 1, 4)
    cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])

    # infer dims from the first row
    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        collate = collate_pair_pooled
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled

    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        collate = collate_pair_unpooled
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    best = -1e9
    bad = 0
    patience = 10

    for ep in range(1, 61):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        trial.report(rho, ep)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if rho > best + 1e-6:
            best = rho
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    return float(best)

# -----------------------------
# Run: optuna + refit best
# -----------------------------
def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")

    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)

    study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
    best = study.best_trial
    best_params = dict(best.params)

    # refit for longer with the best hyperparameters
    lr = float(best_params["lr"])
    wd = float(best_params["weight_decay"])
    dropout = float(best_params["dropout"])
    hidden = int(best_params["hidden_dim"])
    n_heads = int(best_params["n_heads"])
    n_layers = int(best_params["n_layers"])
    cls_w = float(best_params["cls_weight"])
    batch = int(best_params["batch_size"])

    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_pooled
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
        collate = collate_pair_unpooled
        train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    best_rho = -1e9
    bad = 0
    patience = 20
    best_state = None

    for ep in range(1, 201):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        if rho > best_rho + 1e-6:
            best_rho = rho
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # save
    torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
    with open(out_dir / "best_params.json", "w") as f:
        json.dump(best_params, f, indent=2)

    print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")


if __name__ == "__main__":
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
    ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
    ap.add_argument("--out_dir", type=str, required=True)
    ap.add_argument("--n_trials", type=int, default=50)
    args = ap.parse_args()

    run(
        dataset_path=args.dataset_path,
        out_dir=args.out_dir,
        mode=args.mode,
        n_trials=args.n_trials,
    )
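For reference, a programmatic invocation over one of the paired datasets produced by binding_affinity_split.py might look like this (the exact paths are assumptions based on the pair_* naming and the binding_affinity output folders in this commit):

    run(
        dataset_path="training_classifiers/binding_affinity/pair_wt_smiles_pooled",
        out_dir="training_classifiers/binding_affinity/wt_smiles_pooled",
        mode="pooled",
        n_trials=50,
    )
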
training_classifiers/binding_wt.bash
ADDED
@@ -0,0 +1,31 @@
#!/bin/bash
#SBATCH --job-name=b-data
#SBATCH --partition=dgx-b200
#SBATCH --gpus=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=100G
#SBATCH --time=48:00:00
#SBATCH --output=%x_%j.out

HOME_LOC=/vast/projects/pranam/lab/yz927
SCRIPT_LOC=$HOME_LOC/projects/Classifier_Weight/training_classifiers
DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned
OBJECTIVE='binding_affinity'
WT='smiles'      # wt/smiles
STATUS='pooled'  # pooled/unpooled
DATA_FILE="pair_wt_${WT}_${STATUS}"
LOG_LOC=$SCRIPT_LOC
DATE=$(date +%m_%d)
SPECIAL_PREFIX="binding_affinity_data_generation"

# Create the log directory if it doesn't exist
mkdir -p $LOG_LOC

cd $SCRIPT_LOC
source /vast/projects/pranam/lab/shared/miniconda3/etc/profile.d/conda.sh
conda activate /vast/projects/pranam/lab/shared/miniconda3/envs/metal

python -u binding_affinity_split.py > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1

echo "Script completed at $(date)"
conda deactivate
training_classifiers/hemolysis/cnn_smiles/best_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd8f379ef2a10dacff4236ca37aa64832a3ce8bc9608ca1297b1b7662780ee6f
size 14170677
training_classifiers/hemolysis/cnn_smiles/best_model_benchmark.json
ADDED
@@ -0,0 +1,39 @@
{
  "n_samples": 800,
  "wall_time_s": 6.477548930000921,
  "throughput_samples_per_s": 123.50350551499211,
  "gpu_total_kernel_ms": 39.01705604791641,
  "gpu_ms_per_sample": 0.04877132005989551,
  "gpu_avg_ms_per_batch": 0.7803411209583282,
  "gpu_peak_mem_MB": 219.23974609375,
  "telemetry_pre": {
    "cpu_freq_current_MHz": 1357.0153303571428,
    "cpu_freq_max_MHz": 4000.0,
    "cpu_util_pct": 5.7,
    "cpu_count_logical": 224,
    "cpu_count_physical": 112,
    "gpu_util_pct": 0,
    "gpu_mem_util_pct": 0,
    "gpu_mem_used_MB": 1829.0625,
    "gpu_mem_total_MB": 183359.0,
    "gpu_sm_clock_MHz": 1965,
    "gpu_mem_clock_MHz": 3996,
    "gpu_power_W": 194.631,
    "gpu_temp_C": 31
  },
  "telemetry_post": {
    "cpu_freq_current_MHz": 1529.574044642856,
    "cpu_freq_max_MHz": 4000.0,
    "cpu_util_pct": 6.0,
    "cpu_count_logical": 224,
    "cpu_count_physical": 112,
    "gpu_util_pct": 0,
    "gpu_mem_util_pct": 0,
    "gpu_mem_used_MB": 1923.0625,
    "gpu_mem_total_MB": 183359.0,
    "gpu_sm_clock_MHz": 1965,
    "gpu_mem_clock_MHz": 3996,
    "gpu_power_W": 197.518,
    "gpu_temp_C": 31
  }
}
training_classifiers/hemolysis/cnn_smiles/optimization_summary.txt
ADDED
@@ -0,0 +1,19 @@
========================================================================
MODEL: cnn
Best Optuna F1 (objective): 0.5290
Best Optuna AUC (val, recorded): 0.7477
Best Optuna threshold (val): 0.4518
Refit best AUC (val): 0.7851
Refit best F1@thr (val): 0.5366 at thr=0.5298
Best params:
{
  "lr": 0.0002237456677696451,
  "weight_decay": 0.0005722918417016266,
  "dropout": 0.2697397384794115,
  "batch_size": 16,
  "channels": 512,
  "kernel": 3,
  "layers": 4
}
Saved model: /vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/hemolysis/cnn_smiles/best_model.pt
========================================================================
training_classifiers/hemolysis/cnn_smiles/pr_curve.png
ADDED
training_classifiers/hemolysis/cnn_smiles/roc_curve.png
ADDED
training_classifiers/hemolysis/cnn_smiles/study_trials.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:341f45fdbd7565793d9ff64a3d1be6ebd6d56d6e6957db19d119a946c658d296
size 48177
training_classifiers/hemolysis/cnn_smiles/train_predictions.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44ee47243c96b0d372586ee8c40582af2aa086cfb626c5826b05778f1f08b919
size 1943431
training_classifiers/hemolysis/cnn_smiles/val_predictions.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df88e6577184fbf56018945e0540c5d5679981ff74af08422b03cd6aab43d5e3
size 472104