ribesstefano commited on
Commit
fec8df0
·
1 Parent(s): 0da7f92

Added PROTAC embedding extraction

Browse files
protac_degradation_predictor/__init__.py CHANGED
@@ -28,6 +28,7 @@ from .optuna_utils_xgboost import (
28
  from .protac_degradation_predictor import (
29
  get_protac_active_proba,
30
  is_protac_active,
 
31
  )
32
 
33
  __version__ = "0.0.1"
 
28
  from .protac_degradation_predictor import (
29
  get_protac_active_proba,
30
  is_protac_active,
31
+ get_protac_embedding,
32
  )
33
 
34
  __version__ = "0.0.1"
protac_degradation_predictor/protac_degradation_predictor.py CHANGED
@@ -187,6 +187,8 @@ def get_protac_active_proba(
187
  # Average the predictions of all models
188
  preds = {}
189
  for ckpt_path, model in models.items():
 
 
190
  if not use_xgboost_models:
191
  pred = model(
192
  poi_emb if 'aminoacidcnt' not in ckpt_path else poi_aacnt,
@@ -198,7 +200,6 @@ def get_protac_active_proba(
198
  preds[ckpt_path] = sigmoid(pred).detach().cpu().numpy().flatten()
199
  else:
200
  X = np.hstack([smiles_emb, poi_emb, e3_emb, cell_emb])
201
- # pred = model.inplace_predict(X, (model.best_iteration, model.best_iteration))
202
  pred = model.inplace_predict(X)
203
  preds[ckpt_path] = pred
204
 
@@ -257,4 +258,177 @@ def is_protac_active(
257
  if use_majority_vote:
258
  return pred['majority_vote']
259
  else:
260
- return pred['mean'] > proba_threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  # Average the predictions of all models
188
  preds = {}
189
  for ckpt_path, model in models.items():
190
+ # Get the last part of the path
191
+ ckpt_path = os.path.basename(ckpt_path)
192
  if not use_xgboost_models:
193
  pred = model(
194
  poi_emb if 'aminoacidcnt' not in ckpt_path else poi_aacnt,
 
200
  preds[ckpt_path] = sigmoid(pred).detach().cpu().numpy().flatten()
201
  else:
202
  X = np.hstack([smiles_emb, poi_emb, e3_emb, cell_emb])
 
203
  pred = model.inplace_predict(X)
204
  preds[ckpt_path] = pred
205
 
 
258
  if use_majority_vote:
259
  return pred['majority_vote']
260
  else:
261
+ return pred['mean'] > proba_threshold
262
+
263
+
264
+ def get_protac_embedding(
265
+ protac_smiles: str | List[str],
266
+ e3_ligase: str | List[str],
267
+ target_uniprot: str | List[str],
268
+ cell_line: str | List[str],
269
+ device: Literal['cpu', 'cuda'] = 'cpu',
270
+ use_models_from_cv: bool = False,
271
+ study_type: Literal['standard', 'similarity', 'target'] = 'standard',
272
+ ) -> Dict[str, np.ndarray]:
273
+ """ Get the embeddings of a PROTAC or a list of PROTACs.
274
+
275
+ Args:
276
+ protac_smiles (str | List[str]): The SMILES of the PROTAC.
277
+ e3_ligase (str | List[str]): The Uniprot ID of the E3 ligase.
278
+ target_uniprot (str | List[str]): The Uniprot ID of the target protein.
279
+ cell_line (str | List[str]): The cell line identifier.
280
+ device (str): The device to run the model on.
281
+ use_models_from_cv (bool): Whether to use the models from cross-validation.
282
+ study_type (str): Use models trained on the specified study. Options are 'standard', 'similarity', 'target'.
283
+
284
+ Returns:
285
+ Dict[str, np.ndarray]: The embeddings of the given PROTAC. Each key is the name of the model and the value is the embedding, of shape: (batch_size, model_hidden_size). NOTE: Each model has its own hidden size, so the embeddings might have different dimensions.
286
+ """
287
+ # Check that the study type is valid
288
+ if study_type not in ['standard', 'similarity', 'target']:
289
+ raise ValueError(f"Invalid study type: {study_type}. Options are 'standard', 'similarity', 'target'.")
290
+
291
+ # Check that the device is valid
292
+ if device not in ['cpu', 'cuda']:
293
+ raise ValueError(f"Invalid device: {device}. Options are 'cpu', 'cuda'.")
294
+
295
+ # Check that if any the models input is a list, all inputs are lists
296
+ model_inputs = [protac_smiles, e3_ligase, target_uniprot, cell_line]
297
+ if any(isinstance(i, list) for i in model_inputs):
298
+ if not all(isinstance(i, list) for i in model_inputs):
299
+ raise ValueError("All model inputs must be lists if one of the inputs is a list.")
300
+
301
+ # Load all required models in pkg_resources
302
+ device = torch.device(device)
303
+ models = {}
304
+ model_to_load = 'best_model' if not use_models_from_cv else 'cv_model'
305
+ for model_filename in pkg_resources.resource_listdir(__name__, 'models'):
306
+ if model_to_load not in model_filename:
307
+ continue
308
+ if study_type not in model_filename:
309
+ continue
310
+ if 'xgboost' not in model_filename:
311
+ ckpt_path = pkg_resources.resource_filename(__name__, f'models/{model_filename}')
312
+ models[ckpt_path] = load_model(ckpt_path).to(device)
313
+
314
+ protein2embedding = load_protein2embedding()
315
+ cell2embedding = load_cell2embedding()
316
+
317
+ # Get the dimension of the embeddings from the first np.array in the dictionary
318
+ protein_embedding_size = next(iter(protein2embedding.values())).shape[0]
319
+ cell_embedding_size = next(iter(cell2embedding.values())).shape[0]
320
+ # Setup default embeddings
321
+ default_protein_emb = np.zeros(protein_embedding_size)
322
+ default_cell_emb = np.zeros(cell_embedding_size)
323
+
324
+ # Check if any model name contains cellsonehot, if so, get onehot encoding
325
+ cell2onehot = None
326
+ if any('cellsonehot' in model_name for model_name in models.keys()):
327
+ onehotenc = OneHotEncoder(sparse_output=False)
328
+ cell_embeddings = onehotenc.fit_transform(
329
+ np.array(list(cell2embedding.keys())).reshape(-1, 1)
330
+ )
331
+ cell2onehot = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
332
+
333
+ # Check if any of the model names contain aminoacidcnt, if so, get the CountVectorizer
334
+ protein2aacnt = None
335
+ if any('aminoacidcnt' in model_name for model_name in models.keys()):
336
+ # Create a new protein2embedding dictionary with amino acid sequence
337
+ protac_df = load_curated_dataset()
338
+ # Create the dictionary mapping 'Uniprot' to 'POI Sequence'
339
+ protein2aacnt = protac_df.set_index('Uniprot')['POI Sequence'].to_dict()
340
+ # Create the dictionary mapping 'E3 Ligase Uniprot' to 'E3 Ligase Sequence'
341
+ e32seq = protac_df.set_index('E3 Ligase Uniprot')['E3 Ligase Sequence'].to_dict()
342
+ # Merge the two dictionaries into a new protein2aacnt dictionary
343
+ protein2aacnt.update(e32seq)
344
+
345
+ # Get count vectorized embeddings for proteins
346
+ # NOTE: Check that the protein2aacnt is a dictionary of strings
347
+ if not all(isinstance(k, str) for k in protein2aacnt.keys()):
348
+ raise ValueError("All keys in `protein2aacnt` must be strings.")
349
+ countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
350
+ protein_embeddings = countvec.fit_transform(
351
+ list(protein2aacnt.keys())
352
+ ).toarray()
353
+ protein2aacnt = {k: v for k, v in zip(protein2aacnt.keys(), protein_embeddings)}
354
+
355
+ # Convert the E3 ligase to Uniprot ID
356
+ if isinstance(e3_ligase, list):
357
+ e3_ligase_uniprot = [config.e3_ligase2uniprot.get(e3, '') for e3 in e3_ligase]
358
+ else:
359
+ e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
360
+
361
+ # Get the embeddings for the PROTAC, E3 ligase, target protein, and cell line
362
+ # Check if the input is a list or a single string, in the latter case,
363
+ # convert to a list to create a batch of size 1, len(list) otherwise.
364
+ if isinstance(protac_smiles, list):
365
+ # TODO: Add warning on missing entries?
366
+ smiles_emb = [get_fingerprint(s) for s in protac_smiles]
367
+ cell_emb = [cell2embedding.get(c, default_cell_emb) for c in cell_line]
368
+ e3_emb = [protein2embedding.get(e3, default_protein_emb) for e3 in e3_ligase_uniprot]
369
+ poi_emb = [protein2embedding.get(t, default_protein_emb) for t in target_uniprot]
370
+ # Convert to one-hot encoded cell embeddings if necessary
371
+ if cell2onehot is not None:
372
+ cell_onehot = [cell2onehot.get(c, default_cell_emb) for c in cell_line]
373
+ # Convert to amino acid count embeddings if necessary
374
+ if protein2aacnt is not None:
375
+ poi_aacnt = [protein2aacnt.get(t, default_protein_emb) for t in target_uniprot]
376
+ e3_aacnt = [protein2aacnt.get(e3, default_protein_emb) for e3 in e3_ligase_uniprot]
377
+ else:
378
+ if e3_ligase not in config.e3_ligase2uniprot:
379
+ available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
380
+ logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
381
+ if target_uniprot not in protein2embedding:
382
+ logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
383
+ if cell_line not in cell2embedding:
384
+ logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
385
+ smiles_emb = [get_fingerprint(protac_smiles)]
386
+ cell_emb = [cell2embedding.get(cell_line, default_cell_emb)]
387
+ poi_emb = [protein2embedding.get(target_uniprot, default_protein_emb)]
388
+ e3_emb = [protein2embedding.get(e3_ligase_uniprot, default_protein_emb)]
389
+ # Convert to one-hot encoded cell embeddings if necessary
390
+ if cell2onehot is not None:
391
+ cell_onehot = [cell2onehot.get(cell_line, default_cell_emb)]
392
+ # Convert to amino acid count embeddings if necessary
393
+ if protein2aacnt is not None:
394
+ poi_aacnt = [protein2aacnt.get(target_uniprot, default_protein_emb)]
395
+ e3_aacnt = [protein2aacnt.get(e3_ligase_uniprot, default_protein_emb)]
396
+
397
+ # Convert to numpy arrays
398
+ smiles_emb = np.array(smiles_emb)
399
+ cell_emb = np.array(cell_emb)
400
+ poi_emb = np.array(poi_emb)
401
+ e3_emb = np.array(e3_emb)
402
+ if cell2onehot is not None:
403
+ cell_onehot = np.array(cell_onehot)
404
+ if protein2aacnt is not None:
405
+ poi_aacnt = np.array(poi_aacnt)
406
+ e3_aacnt = np.array(e3_aacnt)
407
+
408
+ # Convert to torch tensors
409
+ smiles_emb = torch.tensor(smiles_emb).float().to(device)
410
+ cell_emb = torch.tensor(cell_emb).to(device)
411
+ poi_emb = torch.tensor(poi_emb).to(device)
412
+ e3_emb = torch.tensor(e3_emb).to(device)
413
+ if cell2onehot is not None:
414
+ cell_onehot = torch.tensor(cell_onehot).float().to(device)
415
+ if protein2aacnt is not None:
416
+ poi_aacnt = torch.tensor(poi_aacnt).float().to(device)
417
+ e3_aacnt = torch.tensor(e3_aacnt).float().to(device)
418
+
419
+ # Average the predictions of all models
420
+ protac_embs = {}
421
+ for ckpt_path, model in models.items():
422
+ # Get the last part of the path
423
+ ckpt_path = os.path.basename(ckpt_path)
424
+ _, protac_emb = model(
425
+ poi_emb if 'aminoacidcnt' not in ckpt_path else poi_aacnt,
426
+ e3_emb if 'aminoacidcnt' not in ckpt_path else e3_aacnt,
427
+ cell_emb if 'cellsonehot' not in ckpt_path else cell_onehot,
428
+ smiles_emb,
429
+ prescaled_embeddings=False, # Normalization performed by the model
430
+ return_embeddings=True,
431
+ )
432
+ protac_embs[ckpt_path] = protac_emb
433
+
434
+ return protac_embs
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -101,7 +101,7 @@ class PROTAC_Predictor(nn.Module):
101
  self.dropout = nn.Dropout(p=dropout)
102
 
103
 
104
- def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
105
  embeddings = []
106
  if self.join_embeddings == 'beginning':
107
  # TODO: Remove this if-branch
@@ -147,8 +147,10 @@ class PROTAC_Predictor(nn.Module):
147
  if torch.isnan(x).any():
148
  raise ValueError("NaN values found in sum of softmax-ed embeddings.")
149
  x = F.relu(self.fc1(x))
150
- x = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
151
- x = self.fc3(x)
 
 
152
  return x
153
 
154
 
@@ -277,7 +279,7 @@ class PROTAC_Model(pl.LightningModule):
277
  tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device) + alpha
278
  return tensor
279
 
280
- def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True):
281
  if not prescaled_embeddings:
282
  if self.apply_scaling:
283
  if self.join_embeddings == 'beginning':
@@ -302,7 +304,7 @@ class PROTAC_Model(pl.LightningModule):
302
  raise ValueError("NaN values found in cell embeddings.")
303
  if torch.isnan(smiles_emb).any():
304
  raise ValueError("NaN values found in SMILES embeddings.")
305
- return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
306
 
307
  def step(self, batch, batch_idx, stage):
308
  poi_emb = batch['poi_emb']
 
101
  self.dropout = nn.Dropout(p=dropout)
102
 
103
 
104
+ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, return_embeddings=False):
105
  embeddings = []
106
  if self.join_embeddings == 'beginning':
107
  # TODO: Remove this if-branch
 
147
  if torch.isnan(x).any():
148
  raise ValueError("NaN values found in sum of softmax-ed embeddings.")
149
  x = F.relu(self.fc1(x))
150
+ h = self.bnorm(x) if self.use_batch_norm else self.self.dropout(x)
151
+ x = self.fc3(h)
152
+ if return_embeddings:
153
+ return x, h
154
  return x
155
 
156
 
 
279
  tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device) + alpha
280
  return tensor
281
 
282
+ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True, return_embeddings=False):
283
  if not prescaled_embeddings:
284
  if self.apply_scaling:
285
  if self.join_embeddings == 'beginning':
 
304
  raise ValueError("NaN values found in cell embeddings.")
305
  if torch.isnan(smiles_emb).any():
306
  raise ValueError("NaN values found in SMILES embeddings.")
307
+ return self.model(poi_emb, e3_emb, cell_emb, smiles_emb, return_embeddings)
308
 
309
  def step(self, batch, batch_idx, stage):
310
  poi_emb = batch['poi_emb']