ribesstefano committed
Commit • 15216c3
Parent(s): 0171744

Working loading and predicting with pretrained model + Added scaling in "pure" pytorch during forwarding

Browse files:
- protac_degradation_predictor/__init__.py +4 -0
- protac_degradation_predictor/models/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt +3 -0
- protac_degradation_predictor/protac_degradation_predictor.py +45 -22
- protac_degradation_predictor/pytorch_models.py +41 -2
- setup.py +1 -1
- tests/test_degradation_prediction.py +45 -0
- tests/test_pytorch_model.py +13 -6
protac_degradation_predictor/__init__.py CHANGED

@@ -16,6 +16,10 @@ from .optuna_utils import (
     hyperparameter_tuning_and_training,
     hyperparameter_tuning_and_training_sklearn,
 )
+from .protac_degradation_predictor import (
+    get_protac_active_proba,
+    is_protac_active,
+)
 
 __version__ = "0.0.1"
 __author__ = "Stefano Ribes"
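With this change the prediction helpers are re-exported at the package root, so callers no longer need to reach into the protac_degradation_predictor.protac_degradation_predictor submodule. A minimal sketch of the new import surface:

# Sketch: the re-export makes the prediction API importable from the package root.
from protac_degradation_predictor import (
    get_protac_active_proba,
    is_protac_active,
)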
protac_degradation_predictor/models/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52e12060ef0d21f4eb4d84570c708e8c8502a1fbe6ebcbae1d86044e23b77708
+size 101362856
protac_degradation_predictor/protac_degradation_predictor.py CHANGED

@@ -1,5 +1,6 @@
 import pkg_resources
 import logging
+from typing import List
 
 from .pytorch_models import PROTAC_Model, load_model
 from .data_utils import (
@@ -15,37 +16,48 @@ from torch import sigmoid
 
 
 def get_protac_active_proba(
-    protac_smiles: str,
-    e3_ligase: str,
-    target_uniprot: str,
-    cell_line: str,
+    protac_smiles: str | List[str],
+    e3_ligase: str | List[str],
+    target_uniprot: str | List[str],
+    cell_line: str | List[str],
     device: str = 'cpu',
 ) -> bool:
-
+
+    model_filename = 'best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
+    ckpt_path = pkg_resources.resource_stream(__name__, f'models/{model_filename}')
     model = load_model(ckpt_path).to(device)
     protein2embedding = load_protein2embedding()
-    cell2embedding = load_cell2embedding()
+    cell2embedding = load_cell2embedding('data/cell2embedding.pkl')
 
     # Setup default embeddings
-    if e3_ligase not in config.e3_ligase2uniprot:
-        available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
-        logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
-    if target_uniprot not in protein2embedding:
-        logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
-    if cell_line not in load_cell2embedding():
-        logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
-
     default_protein_emb = np.zeros(config.protein_embedding_size)
     default_cell_emb = np.zeros(config.cell_embedding_size)
-
+
     # Convert the E3 ligase to Uniprot ID
-    e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
+    if isinstance(e3_ligase, list):
+        e3_ligase_uniprot = [config.e3_ligase2uniprot.get(e3, '') for e3 in e3_ligase]
+    else:
+        e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
 
     # Get the embeddings
-    poi_emb = protein2embedding.get(target_uniprot, default_protein_emb)
-    e3_emb = protein2embedding.get(e3_ligase_uniprot, default_protein_emb)
-    cell_emb = cell2embedding.get(cell_line, default_cell_emb)
-    smiles_emb = get_fingerprint(protac_smiles)
+    if isinstance(protac_smiles, list):
+        # TODO: Add warning on missing entries?
+        poi_emb = [protein2embedding.get(t, default_protein_emb) for t in target_uniprot]
+        e3_emb = [protein2embedding.get(e3, default_protein_emb) for e3 in e3_ligase_uniprot]
+        cell_emb = [cell2embedding.get(cell_line, default_cell_emb) for cell_line in cell_line]
+        smiles_emb = [get_fingerprint(protac_smiles) for protac_smiles in protac_smiles]
+    else:
+        if e3_ligase not in config.e3_ligase2uniprot:
+            available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
+            logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
+        if target_uniprot not in protein2embedding:
+            logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
+        if cell_line not in cell2embedding:
+            logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
+        poi_emb = [protein2embedding.get(target_uniprot, default_protein_emb)]
+        e3_emb = [protein2embedding.get(e3_ligase_uniprot, default_protein_emb)]
+        cell_emb = [cell2embedding.get(cell_line, default_cell_emb)]
+        smiles_emb = [get_fingerprint(protac_smiles)]
 
     # Convert to torch tensors
     poi_emb = torch.tensor(poi_emb).to(device)
@@ -53,7 +65,18 @@ def get_protac_active_proba(
     cell_emb = torch.tensor(cell_emb).to(device)
     smiles_emb = torch.tensor(smiles_emb).to(device)
 
-
+    pred = model(
+        poi_emb,
+        e3_emb,
+        cell_emb,
+        smiles_emb,
+        prescaled_embeddings=False,  # Trigger automatic scaling
+    )
+
+    if isinstance(protac_smiles, list):
+        return sigmoid(pred).detach().numpy().flatten()
+    else:
+        return sigmoid(pred).item()
 
 
 def is_protac_active(
@@ -84,4 +107,4 @@ def is_protac_active(
         cell_line,
         device,
     )
-    return
+    return pred > proba_threshold
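The reworked helper now accepts either a single example or equal-length lists of inputs. A usage sketch, with the inputs borrowed from tests/test_degradation_prediction.py further below (the package and its bundled model/data files are assumed to be installed):

# Usage sketch: scalar inputs return a float, list inputs a flat numpy array.
from protac_degradation_predictor import get_protac_active_proba

smiles = 'Cc1ncsc1-c1ccc([C@H](C)NC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)CC(=O)N2CCN(CC[C@H](CSc3ccccc3)Nc3ccc(S(=O)(=O)NC(=O)c4ccc(N5CCN(CC6=C(c7ccc(Cl)cc7)CCC(C)(C)C6)CC5)cc4)cc3S(=O)(=O)C(F)(F)F)CC2)C(C)(C)C)cc1'

proba = get_protac_active_proba(
    protac_smiles=smiles,
    e3_ligase='VHL',
    target_uniprot='Q07817',
    cell_line='MOLT-4',
)
print(proba)  # single float probability in [0, 1]

probas = get_protac_active_proba(
    protac_smiles=[smiles] * 4,
    e3_ligase=['VHL'] * 4,
    target_uniprot=['Q07817'] * 4,
    cell_line=['MOLT-4'] * 4,
)
print(probas)  # numpy array of shape (4,)

Note that the -> bool annotation on get_protac_active_proba is a leftover: as the diff shows, the function returns a probability, not a boolean.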
protac_degradation_predictor/pytorch_models.py CHANGED

@@ -23,6 +23,7 @@ from torchmetrics import (
     MetricCollection,
 )
 from imblearn.over_sampling import SMOTE
+from sklearn.preprocessing import StandardScaler
 
 
 class PROTAC_Predictor(nn.Module):
@@ -239,8 +240,46 @@ class PROTAC_Model(pl.LightningModule):
         self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
         if self.test_dataset:
             self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
 
-    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
+    def scale_tensor(
+        self,
+        tensor: torch.Tensor,
+        scaler: StandardScaler,
+    ) -> torch.Tensor:
+        """Scale a tensor using a scaler. This is done to avoid using numpy
+        arrays (and stay on the same device).
+
+        Args:
+            tensor (torch.Tensor): The tensor to scale.
+            scaler (StandardScaler): The scaler to use.
+
+        Returns:
+            torch.Tensor: The scaled tensor.
+        """
+        tensor = tensor.float()
+        if scaler.with_mean:
+            tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device)
+        if scaler.with_std:
+            tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device)
+        return tensor
+
+    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
+        if not prescaled_embeddings:
+            if self.apply_scaling:
+                if self.join_embeddings == 'beginning':
+                    embeddings = self.scale_tensor(
+                        torch.hstack([smiles_emb, poi_emb, e3_emb, cell_emb]),
+                        self.scalers,
+                    )
+                    smiles_emb = embeddings[:, :self.smiles_emb_dim]
+                    poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim]
+                    e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim]
+                    cell_emb = embeddings[:, -self.cell_emb_dim:]
+                else:
+                    poi_emb = self.scale_tensor(poi_emb, self.scalers['Uniprot'])
+                    e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot'])
+                    cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier'])
+                    smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles'])
         return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
 
     def step(self, batch, batch_idx, stage):
@@ -451,7 +490,7 @@ def train_model(
         ),
         pl.callbacks.EarlyStopping(
             monitor='val_loss',
-            patience=5,
+            patience=10,  # Original: 5
             mode='min',
             verbose=False,
         ),
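The new scale_tensor reimplements StandardScaler.transform with tensor ops, so inference-time inputs can be standardized without a round-trip through numpy and without leaving the tensor's device. A self-contained sketch of that equivalence, on random data for illustration only:

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler

# Fit a scaler the usual sklearn way on a dummy (8, 4) feature matrix.
X = np.random.rand(8, 4).astype(np.float32)
scaler = StandardScaler().fit(X)

# Pure-PyTorch transform: subtract mean_, divide by scale_, as scale_tensor() does.
t = torch.from_numpy(X)
if scaler.with_mean:
    t = t - torch.tensor(scaler.mean_, dtype=t.dtype, device=t.device)
if scaler.with_std:
    t = t / torch.tensor(scaler.scale_, dtype=t.dtype, device=t.device)

# Matches sklearn's numpy-based transform to float32 precision.
assert np.allclose(t.numpy(), scaler.transform(X), atol=1e-5)

One subtlety in the committed version: tensor.float() returns the same tensor when the input is already float32, so the in-place -= and /= there can mutate the caller's tensor; the out-of-place form above sidesteps that.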
setup.py CHANGED

@@ -17,5 +17,5 @@ setuptools.setup(
         "Operating System :: OS Independent",
     ],
     include_package_data=True,
-    package_data={"": ["data/*.h5", "data/*.pkl", "data/*.csv"]},
+    package_data={"": ["data/*.h5", "data/*.pkl", "data/*.csv", "models/*.ckpt"]},
 )
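Adding models/*.ckpt to package_data (together with include_package_data=True) is what lets the predictor resolve the checkpoint from inside the installed package. A quick sanity check, mirroring the pkg_resources lookup used in get_protac_active_proba (assumes the package is installed and the LFS file was fetched):

import pkg_resources

# Resolve the bundled checkpoint exactly as get_protac_active_proba() does.
model_filename = 'best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
ckpt = pkg_resources.resource_stream(
    'protac_degradation_predictor', f'models/{model_filename}'
)
print(len(ckpt.read()))  # should match the git-lfs size above: 101362856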
tests/test_degradation_prediction.py ADDED

@@ -0,0 +1,45 @@
+import pytest
+import os
+import sys
+import logging
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from protac_degradation_predictor import (
+    get_protac_active_proba,
+    is_protac_active,
+)
+
+import torch
+
+
+def test_active_proba():
+    protac_smiles = 'Cc1ncsc1-c1ccc([C@H](C)NC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)CC(=O)N2CCN(CC[C@H](CSc3ccccc3)Nc3ccc(S(=O)(=O)NC(=O)c4ccc(N5CCN(CC6=C(c7ccc(Cl)cc7)CCC(C)(C)C6)CC5)cc4)cc3S(=O)(=O)C(F)(F)F)CC2)C(C)(C)C)cc1'
+    e3_ligase = 'VHL'
+    target_uniprot = 'Q07817'
+    cell_line = 'MOLT-4'
+    device = 'cpu'
+
+    active_prob = get_protac_active_proba(
+        protac_smiles=protac_smiles,
+        e3_ligase=e3_ligase,
+        target_uniprot=target_uniprot,
+        cell_line=cell_line,
+        device=device,
+    )
+
+    print(f'Active probability: {active_prob} (CPU)')
+
+    active_prob = get_protac_active_proba(
+        protac_smiles=[protac_smiles] * 16,
+        e3_ligase=[e3_ligase] * 16,
+        target_uniprot=[target_uniprot] * 16,
+        cell_line=[cell_line] * 16,
+        device='gpu' if torch.cuda.is_available() else 'cpu',
+    )
+
+    print(f'Active probability: {active_prob} (GPU)')
+
+
+def test_is_protac_active():
+    pass
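One caveat in the batched call above: PyTorch recognizes 'cpu' and 'cuda' (optionally 'cuda:0') as device strings, not 'gpu', so .to('gpu') would raise on a CUDA machine. The conventional selection idiom:

import torch

# 'cuda', not 'gpu', is the accelerator device string torch accepts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'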
tests/test_pytorch_model.py CHANGED

@@ -40,19 +40,21 @@ def test_protac_predictor():
 def test_load_model(caplog):
     caplog.set_level(logging.WARNING)
 
+    model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
+
     model = PROTAC_Model.load_from_checkpoint(
-
+        model_filename,
         map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
     )
     # apply_scaling: true
     # batch_size: 8
     # cell_emb_dim: 768
     # disabled_embeddings: []
-    # dropout: 0.
+    # dropout: 0.11257777663560328
     # e3_emb_dim: 1024
     # hidden_dim: 768
     # join_embeddings: concat
-    # learning_rate:
+    # learning_rate: 1.843233973932415e-05
     # poi_emb_dim: 1024
     # smiles_emb_dim: 224
     assert model.hidden_dim == 768
@@ -61,20 +63,25 @@ def test_load_model(caplog):
     assert model.e3_emb_dim == 1024
     assert model.cell_emb_dim == 768
     assert model.batch_size == 8
-    assert model.learning_rate ==
-    assert model.dropout == 0.
+    assert model.learning_rate == 1.843233973932415e-05
+    assert model.dropout == 0.11257777663560328
     assert model.join_embeddings == 'concat'
     assert model.disabled_embeddings == []
     assert model.apply_scaling == True
+    print(model.scalers)
 
 
 def test_checkpoint_file():
+    model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
     checkpoint = torch.load(
-
+        model_filename,
         map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
     )
     print(checkpoint.keys())
     print(checkpoint["hyper_parameters"])
     print([k for k, v in checkpoint["state_dict"].items()])
+    import pickle
+
+    print(pickle.loads(checkpoint['scalers']))
 
 pytest.main()
|