ynuozhang committed
Commit · 1c0b97e
1 Parent(s): baf3373

Delete unused embeddings, metrics, and training scripts
Browse files

- .gitattributes +1 -0
- embeddings/binding/data-00000-of-00001.arrow +0 -3
- embeddings/binding/dataset_info.json +0 -3
- embeddings/binding/state.json +0 -3
- embeddings/fast_embedding_generation.py +0 -3
- embeddings/hemolysis/data-00000-of-00001.arrow +0 -3
- embeddings/hemolysis/dataset_info.json +0 -3
- embeddings/hemolysis/shuffled_hemo.csv +0 -3
- embeddings/hemolysis/state.json +0 -3
- embeddings/nonfouling/combined_nonfouling.csv +0 -3
- embeddings/nonfouling/data-00000-of-00001.arrow +0 -3
- embeddings/nonfouling/dataset_info.json +0 -3
- embeddings/nonfouling/state.json +0 -3
- embeddings/permeability/data-00000-of-00001.arrow +0 -3
- embeddings/permeability/dataset_info.json +0 -3
- embeddings/permeability/nc-CPP-processed.csv +0 -3
- embeddings/permeability/state.json +0 -3
- embeddings/solubility/data-00000-of-00001.arrow +0 -3
- embeddings/solubility/dataset_info.json +0 -3
- embeddings/solubility/shuffled_sol.csv +0 -3
- embeddings/solubility/state.json +0 -3
- metrics/binding/best_model_val_correlation.png +0 -3
- metrics/binding/binding_train_correlation.png +0 -3
- metrics/hemolysis/optimization_metrics.txt +0 -3
- metrics/hemolysis/train_classification_plot.png +0 -3
- metrics/hemolysis/train_predictions_binary.csv +0 -3
- metrics/hemolysis/val_classification_plot.png +0 -3
- metrics/hemolysis/val_predictions_binary.csv +0 -3
- metrics/nonfouling/optimization_metrics.txt +0 -22
- metrics/nonfouling/train_classification_plot.png +0 -0
- metrics/nonfouling/train_predictions_binary.csv +0 -3
- metrics/nonfouling/val_classification_plot.png +0 -0
- metrics/nonfouling/val_predictions_binary.csv +0 -3
- metrics/permeability/optimization_metrics.txt +0 -3
- metrics/permeability/train_correlation.png +0 -3
- metrics/permeability/train_predictions.csv +0 -3
- metrics/permeability/val_correlation.png +0 -3
- metrics/permeability/val_predictions.csv +0 -3
- metrics/solubility/optimization_metrics.txt +0 -3
- metrics/solubility/train_classification_plot.png +0 -3
- metrics/solubility/train_predictions_binary.csv +0 -3
- metrics/solubility/val_classification_plot.png +0 -3
- metrics/solubility/val_predictions_binary.csv +0 -3
- train/binary_xg.py +0 -223
- train/binding_affinity_model_clean.ipynb +0 -0
- train/binding_utils.py +0 -291
- train/permeability_xg.py +0 -186
.gitattributes CHANGED

```diff
@@ -88,3 +88,4 @@ training_data filter=lfs diff=lfs merge=lfs -text
 README.md filter=lfs diff=lfs merge=lfs -text
 embeddings filter=lfs diff=lfs merge=lfs -text
 models/binding_affinity_for_smiles.pt filter=lfs diff=lfs merge=lfs -text
+*.csv filter=lfs diff=lfs merge=lfs -text
```
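The new `*.csv` rule (the kind of line `git lfs track "*.csv"` appends) means CSVs are stored in the repository as three-line Git LFS pointers (`version`/`oid`/`size`) rather than raw bytes, which is why the deletions below each remove exactly three pointer lines instead of the data itself. As a minimal illustration (not part of this repo; `parse_lfs_pointer` is a hypothetical helper), a pointer such as the one deleted just below can be parsed like this:

```python
# Minimal sketch: parse a Git LFS pointer file of the kind deleted below.
# A pointer has exactly three fields: version, oid (sha256:<hex>), and size.
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")  # each line is "<key> <value>"
        fields[key] = value
    assert fields["version"] == "https://git-lfs.github.com/spec/v1"
    algo, _, digest = fields["oid"].partition(":")
    return {"algo": algo, "digest": digest, "size": int(fields["size"])}

# The pointer content of embeddings/binding/data-00000-of-00001.arrow, from this commit:
example = """version https://git-lfs.github.com/spec/v1
oid sha256:d9b08ce28b452e9767dfc7c60bd6285421bdc6b791150a5f55158da89c7bda4f
size 15746448"""
print(parse_lfs_pointer(example))  # {'algo': 'sha256', 'digest': 'd9b0...', 'size': 15746448}
```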
embeddings/binding/data-00000-of-00001.arrow DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d9b08ce28b452e9767dfc7c60bd6285421bdc6b791150a5f55158da89c7bda4f
-size 15746448
```

embeddings/binding/dataset_info.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bbe033d1c9b2ec182afa2b1682a4c41c4fdc0cd548c08d7a21c4364dc68b3595
-size 784
```

embeddings/binding/state.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c2b6d15e6c6cf1f18c4f4cdb0ea339adaf7083a1c754266b8ad6a6468484f693
-size 247
```

embeddings/fast_embedding_generation.py DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0396104ebf1dc28b0d297bdfebede1927aba9a23417c9f06cd8f39d999d099d3
-size 3900
```

embeddings/hemolysis/data-00000-of-00001.arrow DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bef85bc99bc3c81c99fe290c0b2ef6b0d43f50c0089c59be7bf24219dd428d05
-size 20965576
```

embeddings/hemolysis/dataset_info.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a0708b61d5f0aa0a11b62205cbd08b504ad6957c271ff3b984c3e3b9457ce9bf
-size 370
```

embeddings/hemolysis/shuffled_hemo.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:994d4faa1f29c59cb7fa9d18faa8605766b4033ae2bc432ea76e0bbdf4876b29
-size 2236370
```

embeddings/hemolysis/state.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:17a8a2f0f7a7c272396e10572248a45faeaae1c6390ef311b551ab97f5eb72b8
-size 247
```

embeddings/nonfouling/combined_nonfouling.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:90f770673946d37f5c8362b3956c789f62ec71a514ab8c46bb2523b3b3d5be2e
-size 28623153
```

embeddings/nonfouling/data-00000-of-00001.arrow DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:44c869d7a6a03ccdf9143c4007bc85c996146e12e40da3efbd455eaf49eba016
-size 81645736
```

embeddings/nonfouling/dataset_info.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2a2517b8ba7deb4b26855d073360ab0c6a38c2780bc02211c03d2f51f9ccbb00
-size 368
```

embeddings/nonfouling/state.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c5bf4d95c1fd57f172217a453408d1339a87f7eca5d02f4cd718bfdde6b519fb
-size 247
```

embeddings/permeability/data-00000-of-00001.arrow DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:82e749eafb2e903ef2dc47255dbe4e489e6db8055b3ba6af4c876d9b1a0f1b38
-size 22250496
```

embeddings/permeability/dataset_info.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a0708b61d5f0aa0a11b62205cbd08b504ad6957c271ff3b984c3e3b9457ce9bf
-size 370
```

embeddings/permeability/nc-CPP-processed.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f0bbf79d5a78023460de72087bbadd8d1f5b841b21b05b2d148d9ccd7fc3a254
-size 1083167
```

embeddings/permeability/state.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e7a5a8e0db7a08e1c26e0115a0c3225d1850b620d00ab24b42769d78ee6208fd
-size 247
```

embeddings/solubility/data-00000-of-00001.arrow DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:36ac428037f8d09d1f45fcd6a61517428c4409638d63230b3ff1d375bdd0e5cb
-size 106655176
```

embeddings/solubility/dataset_info.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a0708b61d5f0aa0a11b62205cbd08b504ad6957c271ff3b984c3e3b9457ce9bf
-size 370
```

embeddings/solubility/shuffled_sol.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dc5a2139e52a6deb8a1eda45eb5f8b8abfd260724dbbd274a824a9aecb929890
-size 49775729
```

embeddings/solubility/state.json DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:403b965b3b855f52632a00e828046dd90749edc7808ef4725d54606673f1d5cb
-size 247
```
metrics/binding/best_model_val_correlation.png DELETED (Git LFS binary)

metrics/binding/binding_train_correlation.png DELETED (Git LFS binary)

metrics/hemolysis/optimization_metrics.txt DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2dc387c4d486b3c94a4237ceb04d8f04e18d9198e0b3c069e04c940b459f5422
-size 612
```

metrics/hemolysis/train_classification_plot.png DELETED (Git LFS binary)

metrics/hemolysis/train_predictions_binary.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5604b65e9f922252ccc92f1bbe9be33270245ada96fef2f3e92dc50c30a5487a
-size 87035
```

metrics/hemolysis/val_classification_plot.png DELETED (Git LFS binary)

metrics/hemolysis/val_predictions_binary.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:00aca5ad646502e56930f2473c596b1cf754dd96ec13c7d2b630eb2e8cec3d95
-size 21607
```

metrics/nonfouling/optimization_metrics.txt DELETED

```diff
@@ -1,22 +0,0 @@
-
-============================================================
-OPTIMIZATION COMPLETE
-============================================================
-Number of finished trials: 200
-
-Best Trial: #52
-Best F1 Score: 0.8774
-Best AUC Score: 0.9327
-Optuna Best Trial Value: 0.8774
-
-Best hyperparameters:
-  lambda: 3.1278404540677405e-06
-  alpha: 2.865349682111457
-  colsample_bytree: 0.6388434847100901
-  subsample: 0.975052331668336
-  learning_rate: 0.1046988967097677
-  max_depth: 5
-  min_child_weight: 283
-  gamma: 0.7863860752901305
-  num_boost_round: 876
-============================================================
```
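For reference, these logged hyperparameters map directly onto the `xgb.train` call in the training scripts deleted further below. A minimal sketch of rebuilding that booster from the recorded values; the arrays `X` and `y` are hypothetical stand-ins for the nonfouling embeddings and labels, which this commit also removes:

```python
import numpy as np
import xgboost as xgb

# Hypothetical stand-ins for the nonfouling embeddings and binary labels.
X = np.random.rand(100, 16).astype(np.float32)
y = np.random.randint(0, 2, size=100)

# Best-trial hyperparameters recorded in the deleted metrics file above;
# num_boost_round (876) is passed separately, as in the training scripts.
params = {
    "objective": "binary:logistic",
    "lambda": 3.1278404540677405e-06,
    "alpha": 2.865349682111457,
    "colsample_bytree": 0.6388434847100901,
    "subsample": 0.975052331668336,
    "learning_rate": 0.1046988967097677,
    "max_depth": 5,
    "min_child_weight": 283,
    "gamma": 0.7863860752901305,
    "tree_method": "hist",
}
booster = xgb.train(params, xgb.DMatrix(X, label=y), num_boost_round=876)
```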
metrics/nonfouling/train_classification_plot.png DELETED (binary file, 32.9 kB)

metrics/nonfouling/train_predictions_binary.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:07931d9e8b63b4284bc7ce77cdcb719ea3b81d16125e14c277fce06c0e3d6b1d
-size 219852
```

metrics/nonfouling/val_classification_plot.png DELETED (binary file, 39.2 kB)

metrics/nonfouling/val_predictions_binary.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ef9845cf64c142ff16fc915402953a1383e36ecb1c76b6174fae75c0dec59cd4
-size 54904
```

metrics/permeability/optimization_metrics.txt DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:aa91d37a6e116d4ff44eb2606e439f3243ddbeed613abc567128dd8112901e0c
-size 676
```

metrics/permeability/train_correlation.png DELETED (Git LFS binary)

metrics/permeability/train_predictions.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bd3817ce0aa54ce50cb3d56c83ae2a3157cd4d4deca7af743f9e24d2b0bfb675
-size 89716
```

metrics/permeability/val_correlation.png DELETED (Git LFS binary)

metrics/permeability/val_predictions.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:26e268ec0f9b5290ecc7b76c9b41e5819e4bc45bb1bb29843f2ae9b781ecab46
-size 22487
```

metrics/solubility/optimization_metrics.txt DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d935afa867779b07a7c99e51c12cdcd49828dd61e4be1ef357926e280fadf648
-size 612
```

metrics/solubility/train_classification_plot.png DELETED (Git LFS binary)

metrics/solubility/train_predictions_binary.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:852f71fe2fe196274e25866bfbea03874c7e75fec923556c96536ada85a71c7c
-size 244268
```

metrics/solubility/val_classification_plot.png DELETED (Git LFS binary)

metrics/solubility/val_predictions_binary.csv DELETED

```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8489b5b047b630571f3959d96e5fbb003f923854e7db777dd266331f586d115d
-size 61087
```
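The deleted prediction CSVs share the column layout written by the training scripts below ('True Label', 'Predicted Probability', 'Predicted Label'). A minimal sketch of recomputing the headline metrics from such a file; the local path is hypothetical, since the real files are removed by this commit:

```python
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score

# Hypothetical local copy; the repo's copies were deleted in this commit.
df = pd.read_csv("val_predictions_binary.csv")

# Columns as written by save_and_plot_binary_predictions in train/binary_xg.py below.
f1 = f1_score(df["True Label"], df["Predicted Label"], average="weighted")
auc = roc_auc_score(df["True Label"], df["Predicted Probability"])
print(f"F1={f1:.4f}, AUC={auc:.4f}")
```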
train/binary_xg.py DELETED (223 lines)

```python
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, f1_score
import optuna
from optuna.trial import TrialState
import xgboost as xgb
import os
from datasets import load_from_disk
from lightning.pytorch import seed_everything
from rdkit import Chem, rdBase, DataStructs
from typing import List
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score

base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"

def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
    """
    Saves the true and predicted values for training and validation sets, and generates binary classification plots.

    Parameters:
    y_true_train (array): True labels for the training set.
    y_pred_train (array): Predicted probabilities for the training set.
    y_true_val (array): True labels for the validation set.
    y_pred_val (array): Predicted probabilities for the validation set.
    threshold (float): Classification threshold for predictions.
    output_path (str): Directory to save the CSV files and plots.
    """
    os.makedirs(output_path, exist_ok=True)

    # Convert probabilities to binary predictions
    y_pred_train_binary = (y_pred_train >= threshold).astype(int)
    y_pred_val_binary = (y_pred_val >= threshold).astype(int)

    # Save training predictions
    train_df = pd.DataFrame({
        'True Label': y_true_train,
        'Predicted Probability': y_pred_train,
        'Predicted Label': y_pred_train_binary
    })
    train_df.to_csv(os.path.join(output_path, 'train_predictions_binary.csv'), index=False)

    # Save validation predictions
    val_df = pd.DataFrame({
        'True Label': y_true_val,
        'Predicted Probability': y_pred_val,
        'Predicted Label': y_pred_val_binary
    })
    val_df.to_csv(os.path.join(output_path, 'val_predictions_binary.csv'), index=False)

    # Plot training predictions
    plot_binary_correlation(
        y_true_train,
        y_pred_train,
        threshold,
        title="Training Set Binary Classification Plot",
        output_file=os.path.join(output_path, 'train_classification_plot.png')
    )

    # Plot validation predictions
    plot_binary_correlation(
        y_true_val,
        y_pred_val,
        threshold,
        title="Validation Set Binary Classification Plot",
        output_file=os.path.join(output_path, 'val_classification_plot.png')
    )

def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
    """
    Generates a scatter plot for binary classification and saves it to a file.

    Parameters:
    y_true (array): True labels.
    y_pred (array): Predicted probabilities.
    threshold (float): Classification threshold for predictions.
    title (str): Title of the plot.
    output_file (str): Path to save the plot.
    """
    # Scatter plot
    plt.figure(figsize=(10, 8))
    plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')

    # Add threshold line
    plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')

    # Add annotations
    plt.title(title)
    plt.xlabel("True Labels")
    plt.ylabel("Predicted Probability")
    plt.legend()

    # Save and show the plot
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()

seed_everything(42)

dataset = load_from_disk(f'{base_path}/data/solubility')

sequences = np.stack(dataset['sequence'])  # Ensure sequences are SMILES strings
labels = np.stack(dataset['labels'])
embeddings = np.stack(dataset['embedding'])

# Initialize best F1 score and model path
best_f1 = -np.inf
best_model_path = f"{base_path}/src/solubility"

# Trial callback
def trial_info_callback(study, trial):
    if study.best_trial == trial:
        print(f"Trial {trial.number}:")
        print(f"  Weighted F1 Score: {trial.value}")


def objective(trial):
    # Define hyperparameters
    params = {
        'objective': 'binary:logistic',
        'lambda': trial.suggest_float('lambda', 1e-8, 50.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 50.0, log=True),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.3),
        'max_depth': trial.suggest_int('max_depth', 2, 15),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 500),
        'gamma': trial.suggest_float('gamma', 0, 10.0),
        'tree_method': 'hist',
        'device': 'cuda:6',
    }

    # Suggest number of boosting rounds
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
    threshold = 0.5  # Initial classification threshold

    # Split the data
    train_idx, val_idx = train_test_split(
        np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
    )
    train_subset = dataset.select(train_idx).with_format("torch")
    val_subset = dataset.select(val_idx).with_format("torch")

    # Extract embeddings and labels for train/validation
    train_embeddings = np.array(train_subset['embedding'])
    valid_embeddings = np.array(val_subset['embedding'])
    train_labels = np.array(train_subset['labels'])
    valid_labels = np.array(val_subset['labels'])

    # Prepare training and validation sets
    dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
    dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)

    # Train the model
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # Predict probabilities
    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    # Calculate metrics
    f1_val = f1_score(valid_labels, (preds_val >= threshold).astype(int), average="weighted")
    auc_val = roc_auc_score(valid_labels, preds_val)
    print(f"Trial {trial.number}: AUC: {auc_val:.3f}, F1 Score: {f1_val:.3f}")

    # Save the model if it has the best F1 score
    current_best = trial.study.user_attrs.get("best_f1", -np.inf)
    if f1_val > current_best:
        trial.study.set_user_attr("best_f1", f1_val)
        trial.study.set_user_attr("best_auc", auc_val)
        trial.study.set_user_attr("best_trial", trial.number)
        os.makedirs(best_model_path, exist_ok=True)

        # Save the model
        model.save_model(os.path.join(best_model_path, "best_model_f1.json"))
        print(f"✓ NEW BEST! Trial {trial.number}: F1={f1_val:.4f}, AUC={auc_val:.4f} - Model saved!")

        # Save and plot binary predictions
        save_and_plot_binary_predictions(
            train_labels, preds_train, valid_labels, preds_val, threshold, best_model_path
        )

    return f1_val

if __name__ == "__main__":
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200)

    # Prepare summary text
    summary = []
    summary.append("\n" + "="*60)
    summary.append("OPTIMIZATION COMPLETE")
    summary.append("="*60)
    summary.append(f"Number of finished trials: {len(study.trials)}")
    summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
    summary.append(f"Best F1 Score: {study.user_attrs.get('best_f1', None):.4f}")
    summary.append(f"Best AUC Score: {study.user_attrs.get('best_auc', None):.4f}")
    summary.append(f"Optuna Best Trial Value: {study.best_trial.value:.4f}")
    summary.append(f"\nBest hyperparameters:")
    for key, value in study.best_trial.params.items():
        summary.append(f"  {key}: {value}")
    summary.append("="*60)

    # Print to console
    for line in summary:
        print(line)

    # Save to file
    metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
    with open(metrics_file, 'w') as f:
        f.write('\n'.join(summary))
    print(f"\n✓ Metrics saved to: {metrics_file}")
```
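Note that binary_xg.py imports precision_recall_curve but never calls it, fixing the decision threshold at 0.5 instead. A minimal sketch of what a tuned threshold could look like; `y_val` and `p_val` are hypothetical stand-ins for the script's `valid_labels` and `preds_val`:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical validation labels and predicted probabilities.
y_val = np.random.randint(0, 2, size=200)
p_val = np.random.rand(200)

precision, recall, thresholds = precision_recall_curve(y_val, p_val)
# F1 at each candidate threshold; precision/recall carry one extra trailing entry.
f1 = 2 * precision[:-1] * recall[:-1] / np.clip(precision[:-1] + recall[:-1], 1e-12, None)
best = thresholds[np.argmax(f1)]
print(f"F1-optimal threshold: {best:.3f} (vs. the fixed 0.5 used in the script)")
```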
train/binding_affinity_model_clean.ipynb DELETED

The diff for this file is too large to render; see the raw diff.
train/binding_utils.py DELETED (291 lines)

```python
from torch import nn
import torch
import torch.nn.functional as F  # needed for the F.softmax calls below; missing in the original file
import numpy as np

def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return x

class MultiHeadAttentionSequence(nn.Module):
    """Multi-head self-attention with a residual connection and layer norm."""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()

        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)

        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)  # scaled dot-product attention
        attention = F.softmax(attention, dim=-1)

        output = torch.matmul(attention, V)
        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
        output = self.W_O(output)
        output = self.dropout(output)
        output = self.layer_norm(output + q)

        return output, attention

class MultiHeadAttentionReciprocal(nn.Module):
    """Reciprocal (bidirectional cross-) attention: one set of attention scores
    is reused, transposed, for the reverse direction, with separate value and
    output projections per direction."""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)
        self.W_V_2 = nn.Linear(d_model, n_head*d_v)
        self.W_O_2 = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm_2 = nn.LayerNorm(d_model)  # defined but unused in forward below
        self.dropout_2 = nn.Dropout(dropout)       # defined but unused in forward below

    def forward(self, q, k, v, v_2):
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()
        batch, len_v_2, _ = v_2.size()

        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
        V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)
        V_2 = V_2.transpose(1, 2)

        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)
        attention_2 = attention.transpose(-2, -1)  # transposed scores for the reverse direction

        attention = F.softmax(attention, dim=-1)
        attention_2 = F.softmax(attention_2, dim=-1)

        output = torch.matmul(attention, V)
        output_2 = torch.matmul(attention_2, V_2)

        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
        output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])

        output = self.W_O(output)
        output_2 = self.W_O_2(output_2)

        output = self.dropout(output)
        output = self.layer_norm(output + q)

        output_2 = self.dropout(output_2)
        output_2 = self.layer_norm(output_2 + k)

        return output, output_2, attention, attention_2


class FFN(nn.Module):
    """Position-wise feed-forward network built from 1x1 convolutions,
    with a residual connection and layer norm."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()

        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = self.layer_1(x.transpose(1, 2))
        output = self.relu(output)
        output = self.layer_2(output)
        output = self.dropout(output)
        output = self.layer_norm(output.transpose(1, 2) + residual)
        return output

class ConvLayer(nn.Module):
    """A single Conv1d followed by ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.relu(out)
        return out


class DilatedCNN(nn.Module):
    """Three parallel stacks of dilated convolutions (kernel sizes 3, 5, 7),
    summed to produce the protein sequence encoding."""

    def __init__(self, d_model, d_hidden):
        super(DilatedCNN, self).__init__()
        self.first_ = nn.ModuleList()
        self.second_ = nn.ModuleList()
        self.third_ = nn.ModuleList()

        dilation_tuple = (1, 2, 3)
        dim_in_tuple = (d_model, d_hidden, d_hidden)
        dim_out_tuple = (d_hidden, d_hidden, d_hidden)

        for i, dilation_rate in enumerate(dilation_tuple):
            self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate,
                                         dilation=dilation_rate))

        for i, dilation_rate in enumerate(dilation_tuple):
            self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate,
                                          dilation=dilation_rate))

        for i, dilation_rate in enumerate(dilation_tuple):
            self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate,
                                         dilation=dilation_rate))

    def forward(self, protein_seq_enc):
        protein_seq_enc = protein_seq_enc.transpose(1, 2)  # B*L*d_model -> B*d_model*L

        first_embedding = protein_seq_enc
        second_embedding = protein_seq_enc
        third_embedding = protein_seq_enc

        for i in range(len(self.first_)):
            first_embedding = self.first_[i](first_embedding)

        for i in range(len(self.second_)):
            second_embedding = self.second_[i](second_embedding)

        for i in range(len(self.third_)):
            third_embedding = self.third_[i](third_embedding)

        protein_seq_enc = first_embedding + second_embedding + third_embedding

        return protein_seq_enc.transpose(1, 2)


class ReciprocalLayerwithCNN(nn.Module):
    """Reciprocal attention block whose protein branch is first encoded by a dilated CNN."""

    def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
        super().__init__()

        self.cnn = DilatedCNN(d_model, d_hidden)
        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
        self.ffn_seq = FFN(d_hidden, d_inner)
        self.ffn_protein = FFN(d_hidden, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        protein_seq_enc = self.cnn(protein_seq_enc)
        prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)
        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention


class ReciprocalLayer(nn.Module):
    """Reciprocal attention block operating directly on d_model-sized encodings."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v):
        super().__init__()

        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
        self.ffn_seq = FFN(d_model, d_inner)
        self.ffn_protein = FFN(d_model, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)
        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
```
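As a quick shape check (not from the repo, with illustrative dimensions), the reciprocal layer takes a peptide encoding and a protein encoding and returns updated encodings plus four attention maps. This assumes the classes above are in scope, e.g. via `from binding_utils import ReciprocalLayer` before the deletion:

```python
import torch
from binding_utils import ReciprocalLayer  # assumes the deleted module is on the path

d_model, d_inner, n_head, d_k, d_v = 64, 128, 4, 16, 16  # illustrative sizes
layer = ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)

seq_enc = torch.randn(2, 10, d_model)   # e.g. peptide tokens: B x L_seq x d_model
prot_enc = torch.randn(2, 50, d_model)  # e.g. protein residues: B x L_prot x d_model

prot_out, seq_out, *attn = layer(seq_enc, prot_enc)
print(prot_out.shape, seq_out.shape)  # torch.Size([2, 50, 64]) torch.Size([2, 10, 64])
print(attn[2].shape)                  # prot->seq attention: B x n_head x L_prot x L_seq
```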
train/permeability_xg.py DELETED (186 lines)

```python
import pandas as pd
import numpy as np
import optuna
from optuna.trial import TrialState
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import xgboost as xgb
import os
from datasets import load_from_disk
from scipy.stats import spearmanr
import matplotlib.pyplot as plt

base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"

def save_and_plot_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, output_path):
    os.makedirs(output_path, exist_ok=True)

    # Save training predictions
    train_df = pd.DataFrame({'True Permeability': y_true_train, 'Predicted Permeability': y_pred_train})
    train_df.to_csv(os.path.join(output_path, 'train_predictions.csv'), index=False)

    # Save validation predictions
    val_df = pd.DataFrame({'True Permeability': y_true_val, 'Predicted Permeability': y_pred_val})
    val_df.to_csv(os.path.join(output_path, 'val_predictions.csv'), index=False)

    # Plot training predictions
    plot_correlation(
        y_true_train,
        y_pred_train,
        title="Training Set Correlation Plot",
        output_file=os.path.join(output_path, 'train_correlation.png'),
    )

    # Plot validation predictions
    plot_correlation(
        y_true_val,
        y_pred_val,
        title="Validation Set Correlation Plot",
        output_file=os.path.join(output_path, 'val_correlation.png'),
    )

def plot_correlation(y_true, y_pred, title, output_file):
    spearman_corr, _ = spearmanr(y_true, y_pred)

    # Scatter plot
    plt.figure(figsize=(10, 8))
    plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
    plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='teal', linestyle='--', label='Ideal fit')

    # Add annotations
    plt.title(f"{title}\nSpearman Correlation: {spearman_corr:.3f}")
    plt.xlabel("True Permeability (logP)")
    plt.ylabel("Predicted Permeability (logP)")
    plt.legend()

    # Save and show the plot
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()

# Load dataset
dataset = load_from_disk(f'{base_path}/data/permeability')

# Extract sequences, labels, and embeddings
sequences = np.stack(dataset['sequence'])
labels = np.stack(dataset['labels'])         # Regression labels
embeddings = np.stack(dataset['embedding'])  # Pre-trained embeddings

# Function to compute Morgan fingerprints
def compute_morgan_fingerprints(smiles_list, radius=2, n_bits=2048):
    fps = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
            fps.append(np.array(fp))
        else:
            # If the SMILES string is invalid, use a zero vector
            fps.append(np.zeros(n_bits))
            print(f"Invalid SMILES: {smiles}")
    return np.array(fps)

# Compute Morgan fingerprints for the sequences
#morgan_fingerprints = compute_morgan_fingerprints(sequences)

# Concatenate embeddings with Morgan fingerprints
#input_features = np.concatenate([embeddings, morgan_fingerprints], axis=1)
input_features = embeddings

# Initialize global variables
best_model_path = f"{base_path}/src/permeability"
os.makedirs(best_model_path, exist_ok=True)

def trial_info_callback(study, trial):
    if study.best_trial == trial:
        print(f"Trial {trial.number}:")
        print(f"  MSE: {trial.value}")

def objective(trial):
    # Define hyperparameters
    params = {
        'objective': 'reg:squarederror',
        'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
        'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
        'gamma': trial.suggest_float('gamma', 0, 5),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        'subsample': trial.suggest_float('subsample', 0.6, 0.9),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 0.1),
        'max_depth': trial.suggest_int('max_depth', 2, 30),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
        'tree_method': 'hist',
        'scale_pos_weight': trial.suggest_float('scale_pos_weight', 0.5, 10.0, log=True),
        'device': 'cuda:6',
    }
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)

    # Train-validation split
    X_train, X_val, y_train, y_val = train_test_split(input_features, labels, test_size=0.2, random_state=42)

    # Convert data to DMatrix
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_val, label=y_val)

    # Train XGBoost
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # Predict and evaluate
    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    mse = mean_squared_error(y_val, preds_val)

    # Calculate Spearman Rank Correlation for both train and validation
    spearman_train, _ = spearmanr(y_train, preds_train)
    spearman_val, _ = spearmanr(y_val, preds_val)
    print(f"Train Spearman: {spearman_train:.4f}, Val Spearman: {spearman_val:.4f}")

    # Save the best model
    if trial.study.user_attrs.get("best_mse", np.inf) > mse:
        trial.study.set_user_attr("best_mse", mse)
        trial.study.set_user_attr("best_spearman_train", spearman_train)
        trial.study.set_user_attr("best_spearman_val", spearman_val)
        trial.study.set_user_attr("best_trial", trial.number)
        model.save_model(os.path.join(best_model_path, "best_model.json"))
        save_and_plot_predictions(y_train, preds_train, y_val, preds_val, best_model_path)
        print(f"✓ NEW BEST! Trial {trial.number}: MSE={mse:.4f}, Train Spearman={spearman_train:.4f}, Val Spearman={spearman_val:.4f}")

    return mse

if __name__ == "__main__":
    study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])

    # Prepare summary text
    summary = []
    summary.append("\n" + "="*60)
    summary.append("OPTIMIZATION COMPLETE")
    summary.append("="*60)
    summary.append(f"Number of finished trials: {len(study.trials)}")
    summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
    summary.append(f"Best MSE: {study.best_trial.value:.4f}")
    summary.append(f"Best Training Spearman Correlation: {study.user_attrs.get('best_spearman_train', None):.4f}")
    summary.append(f"Best Validation Spearman Correlation: {study.user_attrs.get('best_spearman_val', None):.4f}")
    summary.append(f"\nBest hyperparameters:")
    for key, value in study.best_trial.params.items():
        summary.append(f"  {key}: {value}")
    summary.append("="*60)

    # Print to console
    for line in summary:
        print(line)

    # Save to file
    metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
    with open(metrics_file, 'w') as f:
        f.write('\n'.join(summary))
    print(f"\n✓ Metrics saved to: {metrics_file}")
```
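The script defines compute_morgan_fingerprints but leaves the feature concatenation commented out, training on embeddings alone. A minimal sketch of the combined-features variant those commented lines describe; the SMILES strings and embedding array below are hypothetical stand-ins for the deleted dataset:

```python
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

# Hypothetical stand-ins; the real sequences/embeddings came from the deleted dataset.
smiles = ["CCO", "CC(=O)O", "not_a_smiles"]
embeddings = np.random.rand(3, 8).astype(np.float32)

def compute_morgan_fingerprints(smiles_list, radius=2, n_bits=2048):
    fps = []
    for s in smiles_list:
        mol = Chem.MolFromSmiles(s)
        if mol is not None:
            fps.append(np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)))
        else:
            fps.append(np.zeros(n_bits))  # zero vector for invalid SMILES, as in the script
    return np.array(fps)

# The commented-out combination from the deleted script:
input_features = np.concatenate([embeddings, compute_morgan_fingerprints(smiles)], axis=1)
print(input_features.shape)  # (3, 8 + 2048)
```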