ribesstefano
commited on
Commit
•
f3d4b52
1
Parent(s):
a8d1800
Updated README
Browse files- README.md +30 -13
- notebooks/best_fingerprint_search.ipynb +0 -0
- notebooks/cell_type_embedding.ipynb +49 -1
- notebooks/plot_experimental_results.ipynb +0 -0
- notebooks/predict_unknown_protacs.ipynb +0 -0
- protac_degradation_predictor/models/{best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt → best_model_n0_random-epoch=13-val_acc=0.83-val_roc_auc=0.841-v1.ckpt} +2 -2
- protac_degradation_predictor/models/best_model_n1_random-epoch=8-test_acc=0.78-test_roc_auc=0.851.ckpt +3 -0
- protac_degradation_predictor/models/best_model_n2_random-epoch=9-val_acc=0.80-val_roc_auc=0.841-v1.ckpt +3 -0
- protac_degradation_predictor/models/cv_model_random_fold0-epoch=12-val_acc=0.86-val_roc_auc=0.905.ckpt +3 -0
- protac_degradation_predictor/models/cv_model_random_fold1-epoch=16-val_acc=0.86-val_roc_auc=0.933.ckpt +3 -0
- protac_degradation_predictor/models/cv_model_random_fold2-epoch=16-val_acc=0.86-val_roc_auc=0.908.ckpt +3 -0
- protac_degradation_predictor/models/cv_model_random_fold3-epoch=15-val_acc=0.89-val_roc_auc=0.930.ckpt +3 -0
- protac_degradation_predictor/models/cv_model_random_fold4-epoch=15-val_acc=0.88-val_roc_auc=0.928.ckpt +3 -0
- protac_degradation_predictor/optuna_utils.py +1 -1
- protac_degradation_predictor/protac_degradation_predictor.py +55 -27
- protac_degradation_predictor/pytorch_models.py +25 -24
- src/plot_experiment_results.py +0 -1
- src/run_experiments.py +0 -1
README.md
CHANGED
@@ -1,24 +1,38 @@
|
|
1 |
-
![Maturity level-0](https://img.shields.io/badge/Maturity%20Level-ML--0-red)
|
2 |
|
3 |
-
# PROTAC-Degradation-Predictor
|
4 |
|
5 |
-
|
|
|
|
|
6 |
|
7 |
-
|
8 |
|
9 |
-
|
|
|
|
|
10 |
|
11 |
-
##
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
```bash
|
16 |
pip install .
|
17 |
```
|
18 |
|
19 |
-
##
|
20 |
|
21 |
-
|
22 |
|
23 |
```python
|
24 |
import protac_degradation_predictor as pdp
|
@@ -33,16 +47,19 @@ active_protac = pdp.is_protac_active(
|
|
33 |
e3_ligase,
|
34 |
target_uniprot,
|
35 |
cell_line,
|
36 |
-
device='
|
37 |
proba_threshold=0.5, # Default value
|
38 |
)
|
39 |
|
40 |
print(f'The given PROTAC is: {"active" if active_protac else "inactive"}')
|
41 |
```
|
42 |
|
43 |
-
|
44 |
-
|
|
|
45 |
|
|
|
46 |
|
|
|
47 |
|
48 |
-
|
|
|
1 |
+
<!-- ![Maturity level-0](https://img.shields.io/badge/Maturity%20Level-ML--0-red)
|
2 |
|
3 |
+
# PROTAC-Degradation-Predictor -->
|
4 |
|
5 |
+
<p align="center">
|
6 |
+
<img src="https://img.shields.io/badge/Maturity%20Level-ML--0-red" alt="Maturity level-0">
|
7 |
+
</p>
|
8 |
|
9 |
+
<h1 align="center">PROTAC-Degradation-Predictor</h1>
|
10 |
|
11 |
+
<p align="center">
|
12 |
+
A machine learning-based tool for predicting PROTAC protein degradation activity.
|
13 |
+
</p>
|
14 |
|
15 |
+
## 📚 Table of Contents
|
16 |
|
17 |
+
- [Data Curation](#-data-curation)
|
18 |
+
- [Installation](#-installation)
|
19 |
+
- [Usage](#-usage)
|
20 |
+
|
21 |
+
## 📝 Data Curation
|
22 |
+
|
23 |
+
The code for data curation can be found in the Jupyter notebook [`data_curation.ipynb`](notebooks/data_curation.ipynb).
|
24 |
+
|
25 |
+
## 🚀 Installation
|
26 |
+
|
27 |
+
To install the package, open your terminal and run the following command:
|
28 |
|
29 |
```bash
|
30 |
pip install .
|
31 |
```
|
32 |
|
33 |
+
## 🎯 Usage
|
34 |
|
35 |
+
After installing the package, you can use it as follows:
|
36 |
|
37 |
```python
|
38 |
import protac_degradation_predictor as pdp
|
|
|
47 |
e3_ligase,
|
48 |
target_uniprot,
|
49 |
cell_line,
|
50 |
+
device='cuda', # Default to 'cpu'
|
51 |
proba_threshold=0.5, # Default value
|
52 |
)
|
53 |
|
54 |
print(f'The given PROTAC is: {"active" if active_protac else "inactive"}')
|
55 |
```
|
56 |
|
57 |
+
This example demonstrates how to predict the activity of a PROTAC molecule. The `is_protac_active` function takes the SMILES string of the PROTAC, the E3 ligase, the UniProt ID of the target protein, and the cell line as inputs. It returns whether the PROTAC is active or not.
|
58 |
+
|
59 |
+
## 📈 Training
|
60 |
|
61 |
+
The code for training the model can be found in the file [`run_experiments.py`](src/run_experiments.py).
|
62 |
|
63 |
+
## 📜 License
|
64 |
|
65 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
notebooks/best_fingerprint_search.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/cell_type_embedding.ipynb
CHANGED
@@ -869,6 +869,13 @@
|
|
869 |
"unique_columns_ranking"
|
870 |
]
|
871 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
872 |
{
|
873 |
"cell_type": "code",
|
874 |
"execution_count": 12,
|
@@ -989,6 +996,47 @@
|
|
989 |
" pickle.dump(cell2description, f)"
|
990 |
]
|
991 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
992 |
{
|
993 |
"cell_type": "markdown",
|
994 |
"metadata": {},
|
@@ -1005,7 +1053,7 @@
|
|
1005 |
},
|
1006 |
{
|
1007 |
"cell_type": "code",
|
1008 |
-
"execution_count":
|
1009 |
"metadata": {},
|
1010 |
"outputs": [],
|
1011 |
"source": [
|
|
|
869 |
"unique_columns_ranking"
|
870 |
]
|
871 |
},
|
872 |
+
{
|
873 |
+
"cell_type": "markdown",
|
874 |
+
"metadata": {},
|
875 |
+
"source": [
|
876 |
+
"genome ancestry, karyotypic information, senescence, biotechnology, virology, caution, donor information, sequence variation, characteristics, transfected with, monoclonal antibody target, HLA typing, knockout cell, microsatellite instability, hierarchy (HI), breed/subspecies, derived from site, population, group, monoclonal antibody isotype, cell type, transformant, selected for resistance to, category (CA)."
|
877 |
+
]
|
878 |
+
},
|
879 |
{
|
880 |
"cell_type": "code",
|
881 |
"execution_count": 12,
|
|
|
996 |
" pickle.dump(cell2description, f)"
|
997 |
]
|
998 |
},
|
999 |
+
{
|
1000 |
+
"cell_type": "markdown",
|
1001 |
+
"metadata": {},
|
1002 |
+
"source": [
|
1003 |
+
"\\begin{figure*}[t!]\n",
|
1004 |
+
" \\centering\n",
|
1005 |
+
" \\begin{subfigure}{0.5\\textwidth}\n",
|
1006 |
+
" \\centering\n",
|
1007 |
+
" \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_Accuracy.pdf}\n",
|
1008 |
+
" \\caption{}\n",
|
1009 |
+
" \\label{fig:pytorch_accuracy}\n",
|
1010 |
+
" \\end{subfigure}%\n",
|
1011 |
+
" \\begin{subfigure}{0.5\\textwidth}\n",
|
1012 |
+
" \\centering\n",
|
1013 |
+
" \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_ROC AUC.pdf}\n",
|
1014 |
+
" \\caption{}\n",
|
1015 |
+
" \\label{fig:pytorch_roc_auc}\n",
|
1016 |
+
" \\end{subfigure}\\\\%\n",
|
1017 |
+
" \\begin{subfigure}{0.5\\textwidth}\n",
|
1018 |
+
" \\centering\n",
|
1019 |
+
" \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_F1 Score.pdf}\n",
|
1020 |
+
" \\caption{}\n",
|
1021 |
+
" \\label{fig:pytorch_f1_score}\n",
|
1022 |
+
" \\end{subfigure}%\n",
|
1023 |
+
" \\begin{subfigure}{0.5\\textwidth}\n",
|
1024 |
+
" \\centering\n",
|
1025 |
+
" \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_Precision.pdf}\n",
|
1026 |
+
" \\caption{}\n",
|
1027 |
+
" \\label{fig:pytorch_precision}\n",
|
1028 |
+
" \\end{subfigure}\\\\%\n",
|
1029 |
+
" \\begin{subfigure}{0.5\\textwidth}\n",
|
1030 |
+
" \\centering\n",
|
1031 |
+
" \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_Recall.pdf}\n",
|
1032 |
+
" \\caption{}\n",
|
1033 |
+
" \\label{fig:pytorch_recall}\n",
|
1034 |
+
" \\end{subfigure}%\n",
|
1035 |
+
" \\caption{Performance metrics of the proposed deep learning models. (a) ROC-AUC. (b) F1 score. (c) Precision. (d) Recall.}\n",
|
1036 |
+
" \\label{fig:pytorch_performance}\n",
|
1037 |
+
"\\end{figure*}"
|
1038 |
+
]
|
1039 |
+
},
|
1040 |
{
|
1041 |
"cell_type": "markdown",
|
1042 |
"metadata": {},
|
|
|
1053 |
},
|
1054 |
{
|
1055 |
"cell_type": "code",
|
1056 |
+
"execution_count": 1,
|
1057 |
"metadata": {},
|
1058 |
"outputs": [],
|
1059 |
"source": [
|
notebooks/plot_experimental_results.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/predict_unknown_protacs.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
protac_degradation_predictor/models/{best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt → best_model_n0_random-epoch=13-val_acc=0.83-val_roc_auc=0.841-v1.ckpt}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:497045f4d5f3bcf859db7339f971d9a7c6c2881121fe8841f3c16d3d17f8c3fa
|
3 |
+
size 5127967
|
protac_degradation_predictor/models/best_model_n1_random-epoch=8-test_acc=0.78-test_roc_auc=0.851.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bd7c412c10651dda9c528d57057f630c167a724f86f3dbd933122991406e0563
|
3 |
+
size 2565407
|
protac_degradation_predictor/models/best_model_n2_random-epoch=9-val_acc=0.80-val_roc_auc=0.841-v1.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a32c2eaa95d7f887912074426e092976a384c303b2e07cb05db2e8e5f1a48870
|
3 |
+
size 5127967
|
protac_degradation_predictor/models/cv_model_random_fold0-epoch=12-val_acc=0.86-val_roc_auc=0.905.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aaec6c36b21650fdca607be7978f38b5ee07e71a9d747c14402806de540148b7
|
3 |
+
size 2565087
|
protac_degradation_predictor/models/cv_model_random_fold1-epoch=16-val_acc=0.86-val_roc_auc=0.933.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59334fe5d4013c9575b827ab820bb2dcb9ab64b597cd6571557e18f3cc9707fe
|
3 |
+
size 2565407
|
protac_degradation_predictor/models/cv_model_random_fold2-epoch=16-val_acc=0.86-val_roc_auc=0.908.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fe1ce77c1bf664a8ec6ee5a7a83c685b18fefc42ad0c3069dde2e88e77d0a07e
|
3 |
+
size 5128095
|
protac_degradation_predictor/models/cv_model_random_fold3-epoch=15-val_acc=0.89-val_roc_auc=0.930.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3606040e5077e0b3d9e00fd3ab8c94decb20cb84fc59a9aad0f55f3620f36b8a
|
3 |
+
size 2565087
|
protac_degradation_predictor/models/cv_model_random_fold4-epoch=15-val_acc=0.88-val_roc_auc=0.928.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e634ff4ebb89afeff3e31ff47d8215c8906e1133530630c0abcc938408dffc91
|
3 |
+
size 2565215
|
protac_degradation_predictor/optuna_utils.py
CHANGED
@@ -104,7 +104,7 @@ def get_majority_vote_metrics(
|
|
104 |
'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
|
105 |
'test_precision': Precision(task='binary')(test_preds, y).item(),
|
106 |
'test_recall': Recall(task='binary')(test_preds, y).item(),
|
107 |
-
'
|
108 |
}
|
109 |
return majority_vote_metrics
|
110 |
|
|
|
104 |
'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
|
105 |
'test_precision': Precision(task='binary')(test_preds, y).item(),
|
106 |
'test_recall': Recall(task='binary')(test_preds, y).item(),
|
107 |
+
'test_f1_score': F1Score(task='binary')(test_preds, y).item(),
|
108 |
}
|
109 |
return majority_vote_metrics
|
110 |
|
protac_degradation_predictor/protac_degradation_predictor.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import pkg_resources
|
2 |
import logging
|
3 |
-
from typing import List
|
4 |
|
5 |
from .pytorch_models import PROTAC_Model, load_model
|
6 |
from .data_utils import (
|
@@ -20,12 +20,21 @@ def get_protac_active_proba(
|
|
20 |
e3_ligase: str | List[str],
|
21 |
target_uniprot: str | List[str],
|
22 |
cell_line: str | List[str],
|
23 |
-
device:
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
protein2embedding = load_protein2embedding()
|
30 |
cell2embedding = load_cell2embedding('data/cell2embedding.pkl')
|
31 |
|
@@ -60,32 +69,47 @@ def get_protac_active_proba(
|
|
60 |
smiles_emb = [get_fingerprint(protac_smiles)]
|
61 |
|
62 |
# Convert to torch tensors
|
63 |
-
poi_emb = torch.tensor(poi_emb).to(device)
|
64 |
-
e3_emb = torch.tensor(e3_emb).to(device)
|
65 |
-
cell_emb = torch.tensor(cell_emb).to(device)
|
66 |
-
smiles_emb = torch.tensor(smiles_emb).to(device)
|
67 |
-
|
68 |
-
pred = model(
|
69 |
-
poi_emb,
|
70 |
-
e3_emb,
|
71 |
-
cell_emb,
|
72 |
-
smiles_emb,
|
73 |
-
prescaled_embeddings=False, # Trigger automatic scaling
|
74 |
-
)
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
|
82 |
def is_protac_active(
|
83 |
-
protac_smiles: str,
|
84 |
-
e3_ligase: str,
|
85 |
-
target_uniprot: str,
|
86 |
-
cell_line: str,
|
87 |
device: str = 'cpu',
|
88 |
proba_threshold: float = 0.5,
|
|
|
|
|
89 |
) -> bool:
|
90 |
""" Predict whether a PROTAC is active or not.
|
91 |
|
@@ -106,5 +130,9 @@ def is_protac_active(
|
|
106 |
target_uniprot,
|
107 |
cell_line,
|
108 |
device,
|
|
|
109 |
)
|
110 |
-
|
|
|
|
|
|
|
|
1 |
import pkg_resources
|
2 |
import logging
|
3 |
+
from typing import List, Literal, Dict
|
4 |
|
5 |
from .pytorch_models import PROTAC_Model, load_model
|
6 |
from .data_utils import (
|
|
|
20 |
e3_ligase: str | List[str],
|
21 |
target_uniprot: str | List[str],
|
22 |
cell_line: str | List[str],
|
23 |
+
device: Literal['cpu', 'cuda'] = 'cpu',
|
24 |
+
use_models_from_cv: bool = False,
|
25 |
+
) -> Dict[str, np.ndarray]:
|
26 |
+
""" Predict the probability of a PROTAC being active.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
protac_smiles (str | List[str]): The SMILES of the PROTAC.
|
30 |
+
e3_ligase (str | List[str]): The Uniprot ID of the E3 ligase.
|
31 |
+
target_uniprot (str | List[str]): The Uniprot ID of the target protein.
|
32 |
+
cell_line (str | List[str]): The cell line identifier.
|
33 |
+
device (str): The device to run the model on.
|
34 |
|
35 |
+
Returns:
|
36 |
+
Dict[str, np.ndarray]: The predictions of the model.
|
37 |
+
"""
|
38 |
protein2embedding = load_protein2embedding()
|
39 |
cell2embedding = load_cell2embedding('data/cell2embedding.pkl')
|
40 |
|
|
|
69 |
smiles_emb = [get_fingerprint(protac_smiles)]
|
70 |
|
71 |
# Convert to torch tensors
|
72 |
+
poi_emb = torch.tensor(np.array(poi_emb)).to(device)
|
73 |
+
e3_emb = torch.tensor(np.array(e3_emb)).to(device)
|
74 |
+
cell_emb = torch.tensor(np.array(cell_emb)).to(device)
|
75 |
+
smiles_emb = torch.tensor(np.array(smiles_emb)).float().to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
+
models = {}
|
78 |
+
model_to_load = 'best_model' if not use_models_from_cv else 'cv_model'
|
79 |
+
# Load all models in pkg_resources that start with 'model_to_load'
|
80 |
+
for model_filename in pkg_resources.resource_listdir(__name__, 'models'):
|
81 |
+
if model_filename.startswith(model_to_load):
|
82 |
+
ckpt_path = pkg_resources.resource_stream(__name__, f'models/{model_filename}')
|
83 |
+
models[ckpt_path] = load_model(ckpt_path).to(device)
|
84 |
+
|
85 |
+
# Average the predictions of all models
|
86 |
+
preds = {}
|
87 |
+
for ckpt_path, model in models.items():
|
88 |
+
pred = model(
|
89 |
+
poi_emb,
|
90 |
+
e3_emb,
|
91 |
+
cell_emb,
|
92 |
+
smiles_emb,
|
93 |
+
prescaled_embeddings=False, # Normalization performed by the model
|
94 |
+
)
|
95 |
+
preds[ckpt_path] = sigmoid(pred).detach().numpy().flatten()
|
96 |
+
axis = 1 if isinstance(protac_smiles, list) else None
|
97 |
+
return {
|
98 |
+
'preds': np.array(list(preds.values())),
|
99 |
+
'mean': np.mean(list(preds.values()), axis=axis),
|
100 |
+
'majority_vote': np.mean(list(preds.values()), axis=axis) > 0.5,
|
101 |
+
}
|
102 |
|
103 |
|
104 |
def is_protac_active(
|
105 |
+
protac_smiles: str | List[str],
|
106 |
+
e3_ligase: str | List[str],
|
107 |
+
target_uniprot: str | List[str],
|
108 |
+
cell_line: str | List[str],
|
109 |
device: str = 'cpu',
|
110 |
proba_threshold: float = 0.5,
|
111 |
+
use_majority_vote: bool = False,
|
112 |
+
use_models_from_cv: bool = False,
|
113 |
) -> bool:
|
114 |
""" Predict whether a PROTAC is active or not.
|
115 |
|
|
|
130 |
target_uniprot,
|
131 |
cell_line,
|
132 |
device,
|
133 |
+
use_models_from_cv,
|
134 |
)
|
135 |
+
if use_majority_vote:
|
136 |
+
return pred['majority_vote']
|
137 |
+
else:
|
138 |
+
return pred['mean'] > proba_threshold
|
protac_degradation_predictor/pytorch_models.py
CHANGED
@@ -53,14 +53,6 @@ class PROTAC_Predictor(nn.Module):
|
|
53 |
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
|
54 |
"""
|
55 |
super().__init__()
|
56 |
-
self.poi_emb_dim = poi_emb_dim
|
57 |
-
self.e3_emb_dim = e3_emb_dim
|
58 |
-
self.cell_emb_dim = cell_emb_dim
|
59 |
-
self.smiles_emb_dim = smiles_emb_dim
|
60 |
-
self.hidden_dim = hidden_dim
|
61 |
-
self.join_embeddings = join_embeddings
|
62 |
-
self.use_batch_norm = use_batch_norm
|
63 |
-
self.disabled_embeddings = disabled_embeddings
|
64 |
# Set our init args as class attributes
|
65 |
self.__dict__.update(locals())
|
66 |
|
@@ -126,12 +118,24 @@ class PROTAC_Predictor(nn.Module):
|
|
126 |
else:
|
127 |
if 'poi' not in self.disabled_embeddings:
|
128 |
embeddings.append(self.poi_fc(poi_emb))
|
|
|
|
|
|
|
129 |
if 'e3' not in self.disabled_embeddings:
|
130 |
embeddings.append(self.e3_fc(e3_emb))
|
|
|
|
|
|
|
131 |
if 'cell' not in self.disabled_embeddings:
|
132 |
embeddings.append(self.cell_fc(cell_emb))
|
|
|
|
|
|
|
133 |
if 'smiles' not in self.disabled_embeddings:
|
134 |
embeddings.append(self.smiles_emb(smiles_emb))
|
|
|
|
|
|
|
135 |
if self.join_embeddings == 'concat':
|
136 |
x = torch.cat(embeddings, dim=1)
|
137 |
elif self.join_embeddings == 'sum':
|
@@ -140,6 +144,8 @@ class PROTAC_Predictor(nn.Module):
|
|
140 |
x = torch.sum(embeddings, dim=1)
|
141 |
else:
|
142 |
x = embeddings[0]
|
|
|
|
|
143 |
x = F.relu(self.fc1(x))
|
144 |
x = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
|
145 |
x = self.fc3(x)
|
@@ -185,19 +191,6 @@ class PROTAC_Model(pl.LightningModule):
|
|
185 |
apply_scaling (bool): Whether to apply scaling to the embeddings
|
186 |
"""
|
187 |
super().__init__()
|
188 |
-
self.poi_emb_dim = poi_emb_dim
|
189 |
-
self.e3_emb_dim = e3_emb_dim
|
190 |
-
self.cell_emb_dim = cell_emb_dim
|
191 |
-
self.smiles_emb_dim = smiles_emb_dim
|
192 |
-
self.hidden_dim = hidden_dim
|
193 |
-
self.batch_size = batch_size
|
194 |
-
self.learning_rate = learning_rate
|
195 |
-
self.join_embeddings = join_embeddings
|
196 |
-
self.train_dataset = train_dataset
|
197 |
-
self.val_dataset = val_dataset
|
198 |
-
self.test_dataset = test_dataset
|
199 |
-
self.disabled_embeddings = disabled_embeddings
|
200 |
-
self.apply_scaling = apply_scaling
|
201 |
# Set our init args as class attributes
|
202 |
self.__dict__.update(locals()) # Add arguments as attributes
|
203 |
# Save the arguments passed to init
|
@@ -265,6 +258,7 @@ class PROTAC_Model(pl.LightningModule):
|
|
265 |
self,
|
266 |
tensor: torch.Tensor,
|
267 |
scaler: StandardScaler,
|
|
|
268 |
) -> torch.Tensor:
|
269 |
"""Scale a tensor using a scaler. This is done to avoid using numpy
|
270 |
arrays (and stay on the same device).
|
@@ -280,7 +274,7 @@ class PROTAC_Model(pl.LightningModule):
|
|
280 |
if scaler.with_mean:
|
281 |
tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device)
|
282 |
if scaler.with_std:
|
283 |
-
tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device)
|
284 |
return tensor
|
285 |
|
286 |
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
|
@@ -300,6 +294,14 @@ class PROTAC_Model(pl.LightningModule):
|
|
300 |
e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot'])
|
301 |
cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier'])
|
302 |
smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
|
304 |
|
305 |
def step(self, batch, batch_idx, stage):
|
@@ -624,5 +626,4 @@ def load_model(
|
|
624 |
# with other datasets...
|
625 |
# if model.apply_scaling:
|
626 |
# model.apply_scalers()
|
627 |
-
model.eval()
|
628 |
-
return model
|
|
|
53 |
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
|
54 |
"""
|
55 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
# Set our init args as class attributes
|
57 |
self.__dict__.update(locals())
|
58 |
|
|
|
118 |
else:
|
119 |
if 'poi' not in self.disabled_embeddings:
|
120 |
embeddings.append(self.poi_fc(poi_emb))
|
121 |
+
if torch.isnan(embeddings[-1]).any():
|
122 |
+
raise ValueError("NaN values found in POI embeddings.")
|
123 |
+
|
124 |
if 'e3' not in self.disabled_embeddings:
|
125 |
embeddings.append(self.e3_fc(e3_emb))
|
126 |
+
if torch.isnan(embeddings[-1]).any():
|
127 |
+
raise ValueError("NaN values found in E3 embeddings.")
|
128 |
+
|
129 |
if 'cell' not in self.disabled_embeddings:
|
130 |
embeddings.append(self.cell_fc(cell_emb))
|
131 |
+
if torch.isnan(embeddings[-1]).any():
|
132 |
+
raise ValueError("NaN values found in cell embeddings.")
|
133 |
+
|
134 |
if 'smiles' not in self.disabled_embeddings:
|
135 |
embeddings.append(self.smiles_emb(smiles_emb))
|
136 |
+
if torch.isnan(embeddings[-1]).any():
|
137 |
+
raise ValueError("NaN values found in SMILES embeddings.")
|
138 |
+
|
139 |
if self.join_embeddings == 'concat':
|
140 |
x = torch.cat(embeddings, dim=1)
|
141 |
elif self.join_embeddings == 'sum':
|
|
|
144 |
x = torch.sum(embeddings, dim=1)
|
145 |
else:
|
146 |
x = embeddings[0]
|
147 |
+
if torch.isnan(x).any():
|
148 |
+
raise ValueError("NaN values found in sum of softmax-ed embeddings.")
|
149 |
x = F.relu(self.fc1(x))
|
150 |
x = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
|
151 |
x = self.fc3(x)
|
|
|
191 |
apply_scaling (bool): Whether to apply scaling to the embeddings
|
192 |
"""
|
193 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
# Set our init args as class attributes
|
195 |
self.__dict__.update(locals()) # Add arguments as attributes
|
196 |
# Save the arguments passed to init
|
|
|
258 |
self,
|
259 |
tensor: torch.Tensor,
|
260 |
scaler: StandardScaler,
|
261 |
+
alpha: float = 1e-10,
|
262 |
) -> torch.Tensor:
|
263 |
"""Scale a tensor using a scaler. This is done to avoid using numpy
|
264 |
arrays (and stay on the same device).
|
|
|
274 |
if scaler.with_mean:
|
275 |
tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device)
|
276 |
if scaler.with_std:
|
277 |
+
tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device) + alpha
|
278 |
return tensor
|
279 |
|
280 |
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
|
|
|
294 |
e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot'])
|
295 |
cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier'])
|
296 |
smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles'])
|
297 |
+
if torch.isnan(poi_emb).any():
|
298 |
+
raise ValueError("NaN values found in POI embeddings.")
|
299 |
+
if torch.isnan(e3_emb).any():
|
300 |
+
raise ValueError("NaN values found in E3 embeddings.")
|
301 |
+
if torch.isnan(cell_emb).any():
|
302 |
+
raise ValueError("NaN values found in cell embeddings.")
|
303 |
+
if torch.isnan(smiles_emb).any():
|
304 |
+
raise ValueError("NaN values found in SMILES embeddings.")
|
305 |
return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
|
306 |
|
307 |
def step(self, batch, batch_idx, stage):
|
|
|
626 |
# with other datasets...
|
627 |
# if model.apply_scaling:
|
628 |
# model.apply_scalers()
|
629 |
+
return model.eval()
|
|
src/plot_experiment_results.py
CHANGED
@@ -331,7 +331,6 @@ def main():
|
|
331 |
]),
|
332 |
}
|
333 |
|
334 |
-
|
335 |
for split_type in ['random', 'tanimoto', 'uniprot']:
|
336 |
split_metrics = []
|
337 |
for i in range(n_models_for_test):
|
|
|
331 |
]),
|
332 |
}
|
333 |
|
|
|
334 |
for split_type in ['random', 'tanimoto', 'uniprot']:
|
335 |
split_metrics = []
|
336 |
for i in range(n_models_for_test):
|
src/run_experiments.py
CHANGED
@@ -8,7 +8,6 @@ from typing import Literal
|
|
8 |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
9 |
|
10 |
import protac_degradation_predictor as pdp
|
11 |
-
from protac_degradation_predictor.optuna_utils import get_dataframe_stats
|
12 |
|
13 |
import pytorch_lightning as pl
|
14 |
from rdkit import Chem
|
|
|
8 |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
9 |
|
10 |
import protac_degradation_predictor as pdp
|
|
|
11 |
|
12 |
import pytorch_lightning as pl
|
13 |
from rdkit import Chem
|