ribesstefano commited on
Commit
f3d4b52
1 Parent(s): a8d1800

Updated README

Browse files
README.md CHANGED
@@ -1,24 +1,38 @@
1
- ![Maturity level-0](https://img.shields.io/badge/Maturity%20Level-ML--0-red)
2
 
3
- # PROTAC-Degradation-Predictor
4
 
5
- Predicting PROTAC protein degradation activity via machine learning.
 
 
6
 
7
- ## Data Curation
8
 
9
- For data curation code, please refer to the code in the Jupyter notebooks [`data_curation.ipynb`](notebooks/data_curation.ipynb).
 
 
10
 
11
- ## Installing the Package
12
 
13
- To install the package, run the following command:
 
 
 
 
 
 
 
 
 
 
14
 
15
  ```bash
16
  pip install .
17
  ```
18
 
19
- ## Running the Package
20
 
21
- To run the package after installation, here is an example snippet:
22
 
23
  ```python
24
  import protac_degradation_predictor as pdp
@@ -33,16 +47,19 @@ active_protac = pdp.is_protac_active(
33
  e3_ligase,
34
  target_uniprot,
35
  cell_line,
36
- device='gpu', # Default to 'cpu'
37
  proba_threshold=0.5, # Default value
38
  )
39
 
40
  print(f'The given PROTAC is: {"active" if active_protac else "inactive"}')
41
  ```
42
 
43
- > If you're coming from my [thesis repo](https://github.com/ribesstefano/Machine-Learning-for-Predicting-Targeted-Protein-Degradation), I just wanted to create a separate and "less generic" repo for fast prototyping new ideas.
44
- > Stefano.
 
45
 
 
46
 
 
47
 
48
- > Why haven't you trained on more (i.e., the whole) data? We did, and we might just need _way_ more data to get better results...
 
1
+ <!-- ![Maturity level-0](https://img.shields.io/badge/Maturity%20Level-ML--0-red)
2
 
3
+ # PROTAC-Degradation-Predictor -->
4
 
5
+ <p align="center">
6
+ <img src="https://img.shields.io/badge/Maturity%20Level-ML--0-red" alt="Maturity level-0">
7
+ </p>
8
 
9
+ <h1 align="center">PROTAC-Degradation-Predictor</h1>
10
 
11
+ <p align="center">
12
+ A machine learning-based tool for predicting PROTAC protein degradation activity.
13
+ </p>
14
 
15
+ ## 📚 Table of Contents
16
 
17
+ - [Data Curation](#-data-curation)
18
+ - [Installation](#-installation)
19
+ - [Usage](#-usage)
20
+
21
+ ## 📝 Data Curation
22
+
23
+ The code for data curation can be found in the Jupyter notebook [`data_curation.ipynb`](notebooks/data_curation.ipynb).
24
+
25
+ ## 🚀 Installation
26
+
27
+ To install the package, open your terminal and run the following command:
28
 
29
  ```bash
30
  pip install .
31
  ```
32
 
33
+ ## 🎯 Usage
34
 
35
+ After installing the package, you can use it as follows:
36
 
37
  ```python
38
  import protac_degradation_predictor as pdp
 
47
  e3_ligase,
48
  target_uniprot,
49
  cell_line,
50
+ device='cuda', # Default to 'cpu'
51
  proba_threshold=0.5, # Default value
52
  )
53
 
54
  print(f'The given PROTAC is: {"active" if active_protac else "inactive"}')
55
  ```
56
 
57
+ This example demonstrates how to predict the activity of a PROTAC molecule. The `is_protac_active` function takes the SMILES string of the PROTAC, the E3 ligase, the UniProt ID of the target protein, and the cell line as inputs. It returns whether the PROTAC is active or not.
58
+
59
+ ## 📈 Training
60
 
61
+ The code for training the model can be found in the file [`run_experiments.py`](src/run_experiments.py).
62
 
63
+ ## 📜 License
64
 
65
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
notebooks/best_fingerprint_search.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/cell_type_embedding.ipynb CHANGED
@@ -869,6 +869,13 @@
869
  "unique_columns_ranking"
870
  ]
871
  },
 
 
 
 
 
 
 
872
  {
873
  "cell_type": "code",
874
  "execution_count": 12,
@@ -989,6 +996,47 @@
989
  " pickle.dump(cell2description, f)"
990
  ]
991
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
992
  {
993
  "cell_type": "markdown",
994
  "metadata": {},
@@ -1005,7 +1053,7 @@
1005
  },
1006
  {
1007
  "cell_type": "code",
1008
- "execution_count": 48,
1009
  "metadata": {},
1010
  "outputs": [],
1011
  "source": [
 
869
  "unique_columns_ranking"
870
  ]
871
  },
872
+ {
873
+ "cell_type": "markdown",
874
+ "metadata": {},
875
+ "source": [
876
+ "genome ancestry, karyotypic information, senescence, biotechnology, virology, caution, donor information, sequence variation, characteristics, transfected with, monoclonal antibody target, HLA typing, knockout cell, microsatellite instability, hierarchy (HI), breed/subspecies, derived from site, population, group, monoclonal antibody isotype, cell type, transformant, selected for resistance to, category (CA)."
877
+ ]
878
+ },
879
  {
880
  "cell_type": "code",
881
  "execution_count": 12,
 
996
  " pickle.dump(cell2description, f)"
997
  ]
998
  },
999
+ {
1000
+ "cell_type": "markdown",
1001
+ "metadata": {},
1002
+ "source": [
1003
+ "\\begin{figure*}[t!]\n",
1004
+ " \\centering\n",
1005
+ " \\begin{subfigure}{0.5\\textwidth}\n",
1006
+ " \\centering\n",
1007
+ " \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_Accuracy.pdf}\n",
1008
+ " \\caption{}\n",
1009
+ " \\label{fig:pytorch_accuracy}\n",
1010
+ " \\end{subfigure}%\n",
1011
+ " \\begin{subfigure}{0.5\\textwidth}\n",
1012
+ " \\centering\n",
1013
+ " \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_ROC AUC.pdf}\n",
1014
+ " \\caption{}\n",
1015
+ " \\label{fig:pytorch_roc_auc}\n",
1016
+ " \\end{subfigure}\\\\%\n",
1017
+ " \\begin{subfigure}{0.5\\textwidth}\n",
1018
+ " \\centering\n",
1019
+ " \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_F1 Score.pdf}\n",
1020
+ " \\caption{}\n",
1021
+ " \\label{fig:pytorch_f1_score}\n",
1022
+ " \\end{subfigure}%\n",
1023
+ " \\begin{subfigure}{0.5\\textwidth}\n",
1024
+ " \\centering\n",
1025
+ " \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_Precision.pdf}\n",
1026
+ " \\caption{}\n",
1027
+ " \\label{fig:pytorch_precision}\n",
1028
+ " \\end{subfigure}\\\\%\n",
1029
+ " \\begin{subfigure}{0.5\\textwidth}\n",
1030
+ " \\centering\n",
1031
+ " \\includegraphics[width=0.99\\columnwidth]{plots/pytorch_performance_Recall.pdf}\n",
1032
+ " \\caption{}\n",
1033
+ " \\label{fig:pytorch_recall}\n",
1034
+ " \\end{subfigure}%\n",
1035
+ " \\caption{Performance metrics of the proposed deep learning models. (a) ROC-AUC. (b) F1 score. (c) Precision. (d) Recall.}\n",
1036
+ " \\label{fig:pytorch_performance}\n",
1037
+ "\\end{figure*}"
1038
+ ]
1039
+ },
1040
  {
1041
  "cell_type": "markdown",
1042
  "metadata": {},
 
1053
  },
1054
  {
1055
  "cell_type": "code",
1056
+ "execution_count": 1,
1057
  "metadata": {},
1058
  "outputs": [],
1059
  "source": [
notebooks/plot_experimental_results.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/predict_unknown_protacs.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
protac_degradation_predictor/models/{best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt → best_model_n0_random-epoch=13-val_acc=0.83-val_roc_auc=0.841-v1.ckpt} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:52e12060ef0d21f4eb4d84570c708e8c8502a1fbe6ebcbae1d86044e23b77708
3
- size 101362856
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:497045f4d5f3bcf859db7339f971d9a7c6c2881121fe8841f3c16d3d17f8c3fa
3
+ size 5127967
protac_degradation_predictor/models/best_model_n1_random-epoch=8-test_acc=0.78-test_roc_auc=0.851.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd7c412c10651dda9c528d57057f630c167a724f86f3dbd933122991406e0563
3
+ size 2565407
protac_degradation_predictor/models/best_model_n2_random-epoch=9-val_acc=0.80-val_roc_auc=0.841-v1.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a32c2eaa95d7f887912074426e092976a384c303b2e07cb05db2e8e5f1a48870
3
+ size 5127967
protac_degradation_predictor/models/cv_model_random_fold0-epoch=12-val_acc=0.86-val_roc_auc=0.905.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aaec6c36b21650fdca607be7978f38b5ee07e71a9d747c14402806de540148b7
3
+ size 2565087
protac_degradation_predictor/models/cv_model_random_fold1-epoch=16-val_acc=0.86-val_roc_auc=0.933.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59334fe5d4013c9575b827ab820bb2dcb9ab64b597cd6571557e18f3cc9707fe
3
+ size 2565407
protac_degradation_predictor/models/cv_model_random_fold2-epoch=16-val_acc=0.86-val_roc_auc=0.908.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe1ce77c1bf664a8ec6ee5a7a83c685b18fefc42ad0c3069dde2e88e77d0a07e
3
+ size 5128095
protac_degradation_predictor/models/cv_model_random_fold3-epoch=15-val_acc=0.89-val_roc_auc=0.930.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3606040e5077e0b3d9e00fd3ab8c94decb20cb84fc59a9aad0f55f3620f36b8a
3
+ size 2565087
protac_degradation_predictor/models/cv_model_random_fold4-epoch=15-val_acc=0.88-val_roc_auc=0.928.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e634ff4ebb89afeff3e31ff47d8215c8906e1133530630c0abcc938408dffc91
3
+ size 2565215
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -104,7 +104,7 @@ def get_majority_vote_metrics(
104
  'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
105
  'test_precision': Precision(task='binary')(test_preds, y).item(),
106
  'test_recall': Recall(task='binary')(test_preds, y).item(),
107
- 'test_f1': F1Score(task='binary')(test_preds, y).item(),
108
  }
109
  return majority_vote_metrics
110
 
 
104
  'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
105
  'test_precision': Precision(task='binary')(test_preds, y).item(),
106
  'test_recall': Recall(task='binary')(test_preds, y).item(),
107
+ 'test_f1_score': F1Score(task='binary')(test_preds, y).item(),
108
  }
109
  return majority_vote_metrics
110
 
protac_degradation_predictor/protac_degradation_predictor.py CHANGED
@@ -1,6 +1,6 @@
1
  import pkg_resources
2
  import logging
3
- from typing import List
4
 
5
  from .pytorch_models import PROTAC_Model, load_model
6
  from .data_utils import (
@@ -20,12 +20,21 @@ def get_protac_active_proba(
20
  e3_ligase: str | List[str],
21
  target_uniprot: str | List[str],
22
  cell_line: str | List[str],
23
- device: str = 'cpu',
24
- ) -> bool:
 
 
 
 
 
 
 
 
 
25
 
26
- model_filename = 'best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
27
- ckpt_path = pkg_resources.resource_stream(__name__, f'models/{model_filename}')
28
- model = load_model(ckpt_path).to(device)
29
  protein2embedding = load_protein2embedding()
30
  cell2embedding = load_cell2embedding('data/cell2embedding.pkl')
31
 
@@ -60,32 +69,47 @@ def get_protac_active_proba(
60
  smiles_emb = [get_fingerprint(protac_smiles)]
61
 
62
  # Convert to torch tensors
63
- poi_emb = torch.tensor(poi_emb).to(device)
64
- e3_emb = torch.tensor(e3_emb).to(device)
65
- cell_emb = torch.tensor(cell_emb).to(device)
66
- smiles_emb = torch.tensor(smiles_emb).to(device)
67
-
68
- pred = model(
69
- poi_emb,
70
- e3_emb,
71
- cell_emb,
72
- smiles_emb,
73
- prescaled_embeddings=False, # Trigger automatic scaling
74
- )
75
 
76
- if isinstance(protac_smiles, list):
77
- return sigmoid(pred).detach().numpy().flatten()
78
- else:
79
- return sigmoid(pred).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  def is_protac_active(
83
- protac_smiles: str,
84
- e3_ligase: str,
85
- target_uniprot: str,
86
- cell_line: str,
87
  device: str = 'cpu',
88
  proba_threshold: float = 0.5,
 
 
89
  ) -> bool:
90
  """ Predict whether a PROTAC is active or not.
91
 
@@ -106,5 +130,9 @@ def is_protac_active(
106
  target_uniprot,
107
  cell_line,
108
  device,
 
109
  )
110
- return pred > proba_threshold
 
 
 
 
1
  import pkg_resources
2
  import logging
3
+ from typing import List, Literal, Dict
4
 
5
  from .pytorch_models import PROTAC_Model, load_model
6
  from .data_utils import (
 
20
  e3_ligase: str | List[str],
21
  target_uniprot: str | List[str],
22
  cell_line: str | List[str],
23
+ device: Literal['cpu', 'cuda'] = 'cpu',
24
+ use_models_from_cv: bool = False,
25
+ ) -> Dict[str, np.ndarray]:
26
+ """ Predict the probability of a PROTAC being active.
27
+
28
+ Args:
29
+ protac_smiles (str | List[str]): The SMILES of the PROTAC.
30
+ e3_ligase (str | List[str]): The Uniprot ID of the E3 ligase.
31
+ target_uniprot (str | List[str]): The Uniprot ID of the target protein.
32
+ cell_line (str | List[str]): The cell line identifier.
33
+ device (str): The device to run the model on.
34
 
35
+ Returns:
36
+ Dict[str, np.ndarray]: The predictions of the model.
37
+ """
38
  protein2embedding = load_protein2embedding()
39
  cell2embedding = load_cell2embedding('data/cell2embedding.pkl')
40
 
 
69
  smiles_emb = [get_fingerprint(protac_smiles)]
70
 
71
  # Convert to torch tensors
72
+ poi_emb = torch.tensor(np.array(poi_emb)).to(device)
73
+ e3_emb = torch.tensor(np.array(e3_emb)).to(device)
74
+ cell_emb = torch.tensor(np.array(cell_emb)).to(device)
75
+ smiles_emb = torch.tensor(np.array(smiles_emb)).float().to(device)
 
 
 
 
 
 
 
 
76
 
77
+ models = {}
78
+ model_to_load = 'best_model' if not use_models_from_cv else 'cv_model'
79
+ # Load all models in pkg_resources that start with 'model_to_load'
80
+ for model_filename in pkg_resources.resource_listdir(__name__, 'models'):
81
+ if model_filename.startswith(model_to_load):
82
+ ckpt_path = pkg_resources.resource_stream(__name__, f'models/{model_filename}')
83
+ models[ckpt_path] = load_model(ckpt_path).to(device)
84
+
85
+ # Average the predictions of all models
86
+ preds = {}
87
+ for ckpt_path, model in models.items():
88
+ pred = model(
89
+ poi_emb,
90
+ e3_emb,
91
+ cell_emb,
92
+ smiles_emb,
93
+ prescaled_embeddings=False, # Normalization performed by the model
94
+ )
95
+ preds[ckpt_path] = sigmoid(pred).detach().numpy().flatten()
96
+ axis = 1 if isinstance(protac_smiles, list) else None
97
+ return {
98
+ 'preds': np.array(list(preds.values())),
99
+ 'mean': np.mean(list(preds.values()), axis=axis),
100
+ 'majority_vote': np.mean(list(preds.values()), axis=axis) > 0.5,
101
+ }
102
 
103
 
104
  def is_protac_active(
105
+ protac_smiles: str | List[str],
106
+ e3_ligase: str | List[str],
107
+ target_uniprot: str | List[str],
108
+ cell_line: str | List[str],
109
  device: str = 'cpu',
110
  proba_threshold: float = 0.5,
111
+ use_majority_vote: bool = False,
112
+ use_models_from_cv: bool = False,
113
  ) -> bool:
114
  """ Predict whether a PROTAC is active or not.
115
 
 
130
  target_uniprot,
131
  cell_line,
132
  device,
133
+ use_models_from_cv,
134
  )
135
+ if use_majority_vote:
136
+ return pred['majority_vote']
137
+ else:
138
+ return pred['mean'] > proba_threshold
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -53,14 +53,6 @@ class PROTAC_Predictor(nn.Module):
53
  disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
54
  """
55
  super().__init__()
56
- self.poi_emb_dim = poi_emb_dim
57
- self.e3_emb_dim = e3_emb_dim
58
- self.cell_emb_dim = cell_emb_dim
59
- self.smiles_emb_dim = smiles_emb_dim
60
- self.hidden_dim = hidden_dim
61
- self.join_embeddings = join_embeddings
62
- self.use_batch_norm = use_batch_norm
63
- self.disabled_embeddings = disabled_embeddings
64
  # Set our init args as class attributes
65
  self.__dict__.update(locals())
66
 
@@ -126,12 +118,24 @@ class PROTAC_Predictor(nn.Module):
126
  else:
127
  if 'poi' not in self.disabled_embeddings:
128
  embeddings.append(self.poi_fc(poi_emb))
 
 
 
129
  if 'e3' not in self.disabled_embeddings:
130
  embeddings.append(self.e3_fc(e3_emb))
 
 
 
131
  if 'cell' not in self.disabled_embeddings:
132
  embeddings.append(self.cell_fc(cell_emb))
 
 
 
133
  if 'smiles' not in self.disabled_embeddings:
134
  embeddings.append(self.smiles_emb(smiles_emb))
 
 
 
135
  if self.join_embeddings == 'concat':
136
  x = torch.cat(embeddings, dim=1)
137
  elif self.join_embeddings == 'sum':
@@ -140,6 +144,8 @@ class PROTAC_Predictor(nn.Module):
140
  x = torch.sum(embeddings, dim=1)
141
  else:
142
  x = embeddings[0]
 
 
143
  x = F.relu(self.fc1(x))
144
  x = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
145
  x = self.fc3(x)
@@ -185,19 +191,6 @@ class PROTAC_Model(pl.LightningModule):
185
  apply_scaling (bool): Whether to apply scaling to the embeddings
186
  """
187
  super().__init__()
188
- self.poi_emb_dim = poi_emb_dim
189
- self.e3_emb_dim = e3_emb_dim
190
- self.cell_emb_dim = cell_emb_dim
191
- self.smiles_emb_dim = smiles_emb_dim
192
- self.hidden_dim = hidden_dim
193
- self.batch_size = batch_size
194
- self.learning_rate = learning_rate
195
- self.join_embeddings = join_embeddings
196
- self.train_dataset = train_dataset
197
- self.val_dataset = val_dataset
198
- self.test_dataset = test_dataset
199
- self.disabled_embeddings = disabled_embeddings
200
- self.apply_scaling = apply_scaling
201
  # Set our init args as class attributes
202
  self.__dict__.update(locals()) # Add arguments as attributes
203
  # Save the arguments passed to init
@@ -265,6 +258,7 @@ class PROTAC_Model(pl.LightningModule):
265
  self,
266
  tensor: torch.Tensor,
267
  scaler: StandardScaler,
 
268
  ) -> torch.Tensor:
269
  """Scale a tensor using a scaler. This is done to avoid using numpy
270
  arrays (and stay on the same device).
@@ -280,7 +274,7 @@ class PROTAC_Model(pl.LightningModule):
280
  if scaler.with_mean:
281
  tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device)
282
  if scaler.with_std:
283
- tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device)
284
  return tensor
285
 
286
  def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
@@ -300,6 +294,14 @@ class PROTAC_Model(pl.LightningModule):
300
  e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot'])
301
  cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier'])
302
  smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles'])
 
 
 
 
 
 
 
 
303
  return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
304
 
305
  def step(self, batch, batch_idx, stage):
@@ -624,5 +626,4 @@ def load_model(
624
  # with other datasets...
625
  # if model.apply_scaling:
626
  # model.apply_scalers()
627
- model.eval()
628
- return model
 
53
  disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
54
  """
55
  super().__init__()
 
 
 
 
 
 
 
 
56
  # Set our init args as class attributes
57
  self.__dict__.update(locals())
58
 
 
118
  else:
119
  if 'poi' not in self.disabled_embeddings:
120
  embeddings.append(self.poi_fc(poi_emb))
121
+ if torch.isnan(embeddings[-1]).any():
122
+ raise ValueError("NaN values found in POI embeddings.")
123
+
124
  if 'e3' not in self.disabled_embeddings:
125
  embeddings.append(self.e3_fc(e3_emb))
126
+ if torch.isnan(embeddings[-1]).any():
127
+ raise ValueError("NaN values found in E3 embeddings.")
128
+
129
  if 'cell' not in self.disabled_embeddings:
130
  embeddings.append(self.cell_fc(cell_emb))
131
+ if torch.isnan(embeddings[-1]).any():
132
+ raise ValueError("NaN values found in cell embeddings.")
133
+
134
  if 'smiles' not in self.disabled_embeddings:
135
  embeddings.append(self.smiles_emb(smiles_emb))
136
+ if torch.isnan(embeddings[-1]).any():
137
+ raise ValueError("NaN values found in SMILES embeddings.")
138
+
139
  if self.join_embeddings == 'concat':
140
  x = torch.cat(embeddings, dim=1)
141
  elif self.join_embeddings == 'sum':
 
144
  x = torch.sum(embeddings, dim=1)
145
  else:
146
  x = embeddings[0]
147
+ if torch.isnan(x).any():
148
+ raise ValueError("NaN values found in sum of softmax-ed embeddings.")
149
  x = F.relu(self.fc1(x))
150
  x = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
151
  x = self.fc3(x)
 
191
  apply_scaling (bool): Whether to apply scaling to the embeddings
192
  """
193
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  # Set our init args as class attributes
195
  self.__dict__.update(locals()) # Add arguments as attributes
196
  # Save the arguments passed to init
 
258
  self,
259
  tensor: torch.Tensor,
260
  scaler: StandardScaler,
261
+ alpha: float = 1e-10,
262
  ) -> torch.Tensor:
263
  """Scale a tensor using a scaler. This is done to avoid using numpy
264
  arrays (and stay on the same device).
 
274
  if scaler.with_mean:
275
  tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device)
276
  if scaler.with_std:
277
+ tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device) + alpha
278
  return tensor
279
 
280
  def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
 
294
  e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot'])
295
  cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier'])
296
  smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles'])
297
+ if torch.isnan(poi_emb).any():
298
+ raise ValueError("NaN values found in POI embeddings.")
299
+ if torch.isnan(e3_emb).any():
300
+ raise ValueError("NaN values found in E3 embeddings.")
301
+ if torch.isnan(cell_emb).any():
302
+ raise ValueError("NaN values found in cell embeddings.")
303
+ if torch.isnan(smiles_emb).any():
304
+ raise ValueError("NaN values found in SMILES embeddings.")
305
  return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
306
 
307
  def step(self, batch, batch_idx, stage):
 
626
  # with other datasets...
627
  # if model.apply_scaling:
628
  # model.apply_scalers()
629
+ return model.eval()
 
src/plot_experiment_results.py CHANGED
@@ -331,7 +331,6 @@ def main():
331
  ]),
332
  }
333
 
334
-
335
  for split_type in ['random', 'tanimoto', 'uniprot']:
336
  split_metrics = []
337
  for i in range(n_models_for_test):
 
331
  ]),
332
  }
333
 
 
334
  for split_type in ['random', 'tanimoto', 'uniprot']:
335
  split_metrics = []
336
  for i in range(n_models_for_test):
src/run_experiments.py CHANGED
@@ -8,7 +8,6 @@ from typing import Literal
8
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
 
10
  import protac_degradation_predictor as pdp
11
- from protac_degradation_predictor.optuna_utils import get_dataframe_stats
12
 
13
  import pytorch_lightning as pl
14
  from rdkit import Chem
 
8
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
 
10
  import protac_degradation_predictor as pdp
 
11
 
12
  import pytorch_lightning as pl
13
  from rdkit import Chem