ribesstefano committed on
Commit 15216c3
1 Parent(s): 0171744

Working loading and predicting with pretrained model + added scaling in "pure" PyTorch during the forward pass

protac_degradation_predictor/__init__.py CHANGED
@@ -16,6 +16,10 @@ from .optuna_utils import (
     hyperparameter_tuning_and_training,
     hyperparameter_tuning_and_training_sklearn,
 )
+from .protac_degradation_predictor import (
+    get_protac_active_proba,
+    is_protac_active,
+)
 
 __version__ = "0.0.1"
 __author__ = "Stefano Ribes"
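With this re-export, the inference API becomes available at the package root, i.e. `from protac_degradation_predictor import get_protac_active_proba, is_protac_active`, instead of requiring callers to reach into the submodule.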
protac_degradation_predictor/models/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52e12060ef0d21f4eb4d84570c708e8c8502a1fbe6ebcbae1d86044e23b77708
+size 101362856
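This is a Git LFS pointer file: only the spec version, the SHA-256 object ID, and the blob size (~101 MB) are committed to the repository, while the checkpoint weights themselves live in LFS storage.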
protac_degradation_predictor/protac_degradation_predictor.py CHANGED
@@ -1,5 +1,6 @@
 import pkg_resources
 import logging
+from typing import List
 
 from .pytorch_models import PROTAC_Model, load_model
 from .data_utils import (
@@ -15,37 +16,48 @@ from torch import sigmoid
 
 
 def get_protac_active_proba(
-    protac_smiles: str,
-    e3_ligase: str,
-    target_uniprot: str,
-    cell_line: str,
+    protac_smiles: str | List[str],
+    e3_ligase: str | List[str],
+    target_uniprot: str | List[str],
+    cell_line: str | List[str],
     device: str = 'cpu',
 ) -> bool:
-    ckpt_path = pkg_resources.resource_stream(__name__, 'data/model.ckpt')
+
+    model_filename = 'best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
+    ckpt_path = pkg_resources.resource_stream(__name__, f'models/{model_filename}')
     model = load_model(ckpt_path).to(device)
     protein2embedding = load_protein2embedding()
-    cell2embedding = load_cell2embedding()
+    cell2embedding = load_cell2embedding('data/cell2embedding.pkl')
 
     # Setup default embeddings
-    if e3_ligase not in config.e3_ligase2uniprot:
-        available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
-        logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
-    if target_uniprot not in protein2embedding:
-        logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
-    if cell_line not in load_cell2embedding():
-        logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
-
     default_protein_emb = np.zeros(config.protein_embedding_size)
     default_cell_emb = np.zeros(config.cell_embedding_size)
-
+
     # Convert the E3 ligase to Uniprot ID
-    e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
+    if isinstance(e3_ligase, list):
+        e3_ligase_uniprot = [config.e3_ligase2uniprot.get(e3, '') for e3 in e3_ligase]
+    else:
+        e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
 
     # Get the embeddings
-    poi_emb = protein2embedding.get(target_uniprot, default_protein_emb)
-    e3_emb = protein2embedding.get(e3_ligase_uniprot, default_protein_emb)
-    cell_emb = cell2embedding.get(cell_line, default_cell_emb)
-    smiles_emb = get_fingerprint(protac_smiles)
+    if isinstance(protac_smiles, list):
+        # TODO: Add warning on missing entries?
+        poi_emb = [protein2embedding.get(t, default_protein_emb) for t in target_uniprot]
+        e3_emb = [protein2embedding.get(e3, default_protein_emb) for e3 in e3_ligase_uniprot]
+        cell_emb = [cell2embedding.get(c, default_cell_emb) for c in cell_line]
+        smiles_emb = [get_fingerprint(s) for s in protac_smiles]
+    else:
+        if e3_ligase not in config.e3_ligase2uniprot:
+            available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
+            logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
+        if target_uniprot not in protein2embedding:
+            logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
+        if cell_line not in cell2embedding:
+            logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
+        poi_emb = [protein2embedding.get(target_uniprot, default_protein_emb)]
+        e3_emb = [protein2embedding.get(e3_ligase_uniprot, default_protein_emb)]
+        cell_emb = [cell2embedding.get(cell_line, default_cell_emb)]
+        smiles_emb = [get_fingerprint(protac_smiles)]
 
     # Convert to torch tensors
     poi_emb = torch.tensor(poi_emb).to(device)
@@ -53,7 +65,18 @@ def get_protac_active_proba(
     cell_emb = torch.tensor(cell_emb).to(device)
     smiles_emb = torch.tensor(smiles_emb).to(device)
 
-    return model(poi_emb, e3_emb, cell_emb, smiles_emb).item()
+    pred = model(
+        poi_emb,
+        e3_emb,
+        cell_emb,
+        smiles_emb,
+        prescaled_embeddings=False,  # Trigger automatic scaling
+    )
+
+    if isinstance(protac_smiles, list):
+        return sigmoid(pred).detach().numpy().flatten()
+    else:
+        return sigmoid(pred).item()
 
 
 def is_protac_active(
@@ -84,4 +107,4 @@ def is_protac_active(
         cell_line,
         device,
     )
-    return sigmoid(pred) > proba_threshold
+    return pred > proba_threshold
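With the list support above, the same entry point now serves single and batched queries. A minimal usage sketch: the SMILES placeholder is hypothetical, while the E3 ligase, UniProt ID, and cell line are the values exercised in the new test below.

from protac_degradation_predictor import get_protac_active_proba

protac_smiles = '...'  # hypothetical placeholder; supply a real PROTAC SMILES

# Single query: returns a float probability via sigmoid(pred).item()
proba = get_protac_active_proba(
    protac_smiles=protac_smiles,
    e3_ligase='VHL',
    target_uniprot='Q07817',
    cell_line='MOLT-4',
)

# Batched query: equal-length lists in, flat numpy array of probabilities out
probas = get_protac_active_proba(
    protac_smiles=[protac_smiles] * 4,
    e3_ligase=['VHL'] * 4,
    target_uniprot=['Q07817'] * 4,
    cell_line=['MOLT-4'] * 4,
)

Note that the batched branch converts predictions with `.detach().numpy()`, which assumes the tensors live on the CPU.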
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -23,6 +23,7 @@ from torchmetrics import (
     MetricCollection,
 )
 from imblearn.over_sampling import SMOTE
+from sklearn.preprocessing import StandardScaler
 
 
 class PROTAC_Predictor(nn.Module):
@@ -239,8 +240,46 @@ class PROTAC_Model(pl.LightningModule):
             self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
         if self.test_dataset:
             self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
+
+    def scale_tensor(
+        self,
+        tensor: torch.Tensor,
+        scaler: StandardScaler,
+    ) -> torch.Tensor:
+        """Scale a tensor using a scaler. This is done to avoid using numpy
+        arrays (and stay on the same device).
+
+        Args:
+            tensor (torch.Tensor): The tensor to scale.
+            scaler (StandardScaler): The scaler to use.
 
+        Returns:
+            torch.Tensor: The scaled tensor.
+        """
+        tensor = tensor.float()
+        if scaler.with_mean:
+            tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device)
+        if scaler.with_std:
+            tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device)
+        return tensor
+
-    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
+    def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
+        if not prescaled_embeddings:
+            if self.apply_scaling:
+                if self.join_embeddings == 'beginning':
+                    embeddings = self.scale_tensor(
+                        torch.hstack([smiles_emb, poi_emb, e3_emb, cell_emb]),
+                        self.scalers,
+                    )
+                    smiles_emb = embeddings[:, :self.smiles_emb_dim]
+                    poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim]
+                    e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim]
+                    cell_emb = embeddings[:, -self.cell_emb_dim:]
+                else:
+                    poi_emb = self.scale_tensor(poi_emb, self.scalers['Uniprot'])
+                    e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot'])
+                    cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier'])
+                    smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles'])
         return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
 
     def step(self, batch, batch_idx, stage):
@@ -451,7 +490,7 @@ def train_model(
         ),
         pl.callbacks.EarlyStopping(
             monitor='val_loss',
-            patience=5,
+            patience=10,  # Original: 5
             mode='min',
             verbose=False,
         ),
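The new `scale_tensor` method mirrors what `StandardScaler.transform` does, but stays in torch so batched inputs never leave the device. A quick standalone sketch of that equivalence (not part of the commit; uses only standard sklearn/torch APIs):

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler

X = np.random.rand(32, 8).astype(np.float32)
scaler = StandardScaler().fit(X)

# Pure-torch replica of scaler.transform: subtract mean_, divide by scale_
t = torch.tensor(X)
t = (t - torch.tensor(scaler.mean_, dtype=t.dtype)) / torch.tensor(scaler.scale_, dtype=t.dtype)

# Matches the numpy transform up to float32 precision
assert np.allclose(t.numpy(), scaler.transform(X), atol=1e-5)

One slicing note: in the `join_embeddings == 'beginning'` branch, the `e3_emb` slice width reuses `self.poi_emb_dim`, which works because the POI and E3 embeddings share the same dimension (both 1024 in the bundled checkpoint).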
setup.py CHANGED
@@ -17,5 +17,5 @@ setuptools.setup(
         "Operating System :: OS Independent",
     ],
     include_package_data=True,
-    package_data={"": ["data/*.h5", "data/*.pkl", "data/*.csv"]},
+    package_data={"": ["data/*.h5", "data/*.pkl", "data/*.csv", "models/*.ckpt"]},
 )
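Together with `include_package_data=True`, the added `models/*.ckpt` glob ships the bundled checkpoint in the installed package, so the `pkg_resources.resource_stream(__name__, f'models/{model_filename}')` call above can resolve it at runtime.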
tests/test_degradation_prediction.py ADDED
@@ -0,0 +1,45 @@
+import pytest
+import os
+import sys
+import logging
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from protac_degradation_predictor import (
+    get_protac_active_proba,
+    is_protac_active,
+)
+
+import torch
+
+
+def test_active_proba():
+    protac_smiles = 'Cc1ncsc1-c1ccc([C@H](C)NC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)CC(=O)N2CCN(CC[C@H](CSc3ccccc3)Nc3ccc(S(=O)(=O)NC(=O)c4ccc(N5CCN(CC6=C(c7ccc(Cl)cc7)CCC(C)(C)C6)CC5)cc4)cc3S(=O)(=O)C(F)(F)F)CC2)C(C)(C)C)cc1'
+    e3_ligase = 'VHL'
+    target_uniprot = 'Q07817'
+    cell_line = 'MOLT-4'
+    device = 'cpu'
+
+    active_prob = get_protac_active_proba(
+        protac_smiles=protac_smiles,
+        e3_ligase=e3_ligase,
+        target_uniprot=target_uniprot,
+        cell_line=cell_line,
+        device=device,
+    )
+
+    print(f'Active probability: {active_prob} (CPU)')
+
+    active_prob = get_protac_active_proba(
+        protac_smiles=[protac_smiles] * 16,
+        e3_ligase=[e3_ligase] * 16,
+        target_uniprot=[target_uniprot] * 16,
+        cell_line=[cell_line] * 16,
+        device='cuda' if torch.cuda.is_available() else 'cpu',  # 'cuda', not 'gpu', is the valid torch device string
+    )
+
+    print(f'Active probability: {active_prob} (GPU)')
+
+
+def test_is_protac_active():
+    pass
tests/test_pytorch_model.py CHANGED
@@ -40,19 +40,21 @@ def test_protac_predictor():
 def test_load_model(caplog):
     caplog.set_level(logging.WARNING)
 
+    model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
+
     model = PROTAC_Model.load_from_checkpoint(
-        'data/test_model.ckpt',
+        model_filename,
         map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
     )
     # apply_scaling: true
     # batch_size: 8
     # cell_emb_dim: 768
     # disabled_embeddings: []
-    # dropout: 0.1498104322091649
+    # dropout: 0.11257777663560328
     # e3_emb_dim: 1024
     # hidden_dim: 768
     # join_embeddings: concat
-    # learning_rate: 4.881387978425994e-05
+    # learning_rate: 1.843233973932415e-05
     # poi_emb_dim: 1024
     # smiles_emb_dim: 224
     assert model.hidden_dim == 768
@@ -61,20 +63,25 @@ def test_load_model(caplog):
     assert model.e3_emb_dim == 1024
     assert model.cell_emb_dim == 768
     assert model.batch_size == 8
-    assert model.learning_rate == 4.881387978425994e-05
-    assert model.dropout == 0.1498104322091649
+    assert model.learning_rate == 1.843233973932415e-05
+    assert model.dropout == 0.11257777663560328
     assert model.join_embeddings == 'concat'
     assert model.disabled_embeddings == []
     assert model.apply_scaling == True
+    print(model.scalers)
 
 
 def test_checkpoint_file():
+    model_filename = 'data/best_model_n0_random-epoch=6-val_acc=0.74-val_roc_auc=0.796.ckpt'
     checkpoint = torch.load(
-        'data/test_model.ckpt',
+        model_filename,
         map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
     )
     print(checkpoint.keys())
    print(checkpoint["hyper_parameters"])
     print([k for k, v in checkpoint["state_dict"].items()])
+    import pickle
+
+    print(pickle.loads(checkpoint['scalers']))
 
 pytest.main()
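The `pickle.loads(checkpoint['scalers'])` call added here suggests the fitted `StandardScaler`s are stored as pickled bytes inside the checkpoint alongside the weights, which is presumably what lets the new `prescaled_embeddings=False` path re-apply them at inference time.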