ribesstefano committed
Commit ea572f9 · 1 Parent(s): 5e01175

started testing package code
protac_degradation_predictor/__init__.py CHANGED
@@ -1,6 +1,14 @@
- from .protac_degradation_predictor import (
-     PROTAC_Model,
-     train_model,
+ # from .protac_degradation_predictor.config import config
+ # from .protac_degradation_predictor.pytorch_models import train_model
+ # from .protac_degradation_predictor.pytorch_models import
+ # from .protac_degradation_predictor.pytorch_models import
+ from . import (
+     config,
+     pytorch_models,
+     sklearn_models,
+     protac_dataset,
+     data_utils,
+     optuna_utils,
  )

  __version__ = "0.0.1"
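With the package-relative imports in place, the submodules become reachable from the top-level package. A minimal usage sketch, assuming the package and its heavy dependencies (RDKit, PyTorch, etc.) are installed or on PYTHONPATH; only the module and attribute names below appear in this commit, the SMILES value is illustrative:

    import protac_degradation_predictor as pdp

    fp = pdp.data_utils.get_fingerprint('CCO')         # Morgan fingerprint built from the config defaults
    vhl = pdp.config.config.e3_ligase2uniprot['VHL']   # 'P40337', via the frozen Config instance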
protac_degradation_predictor/config.py CHANGED
@@ -1,4 +1,4 @@
- from dataclasses import dataclass
+ from dataclasses import dataclass, field

  @dataclass(frozen=True)
  class Config:
@@ -11,27 +11,24 @@ class Config:
      # Data information
      dmax_threshold: float = 0.6
      pdc50_threshold: float = 6.0
-     e3_ligase2uniprot: dict = {
-         'VHL': 'P40337',
-         'CRBN': 'Q96SW2',
-         'DCAF11': 'Q8TEB1',
-         'DCAF15': 'Q66K64',
-         'DCAF16': 'Q9NXF7',
-         'MDM2': 'Q00987',
-         'Mdm2': 'Q00987',
-         'XIAP': 'P98170',
-         'cIAP1': 'Q7Z460',
-         'IAP': 'P98170', # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
-         'Iap': 'P98170', # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
-         'AhR': 'P35869',
-         'RNF4': 'P78317',
-         'RNF114': 'Q9Y508',
-         'FEM1B': 'Q9UK73',
-         'Ubr1': 'Q8IWV7',
-     }
-
-     def __post_init__(self):
-         self.active_label: str = f'Active (Dmax {self.dmax_threshold}, pDC50 {self.pdc50_threshold})'
-
+     active_label: str = field(default=f'Active (Dmax {dmax_threshold}, pDC50 {pdc50_threshold})')
+     e3_ligase2uniprot: dict = field(default_factory=lambda: {
+         'VHL': 'P40337',
+         'CRBN': 'Q96SW2',
+         'DCAF11': 'Q8TEB1',
+         'DCAF15': 'Q66K64',
+         'DCAF16': 'Q9NXF7',
+         'MDM2': 'Q00987',
+         'Mdm2': 'Q00987',
+         'XIAP': 'P98170',
+         'cIAP1': 'Q7Z460',
+         'IAP': 'P98170', # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
+         'Iap': 'P98170', # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
+         'AhR': 'P35869',
+         'RNF4': 'P78317',
+         'RNF114': 'Q9Y508',
+         'FEM1B': 'Q9UK73',
+         'Ubr1': 'Q8IWV7',
+     })

  config = Config()
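Two details are worth noting in this change: dataclasses reject a plain dict as a field default (mutable defaults must go through field(default_factory=...)), and a frozen dataclass cannot assign attributes in __post_init__, which is why active_label is now a regular field whose f-string is evaluated once, at class-definition time, from the class-level defaults. A small sketch of the resulting behaviour (illustrative):

    from protac_degradation_predictor.config import Config, config

    config.active_label               # 'Active (Dmax 0.6, pDC50 6.0)'
    config.e3_ligase2uniprot['CRBN']  # 'Q96SW2'
    # Because the label is baked in at class definition, an instance built with a
    # different threshold keeps the default label unless it is overridden explicitly:
    Config(dmax_threshold=0.5).active_label  # still 'Active (Dmax 0.6, pDC50 6.0)'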
protac_degradation_predictor/data/PROTAC-DB.csv DELETED
The diff for this file is too large to render. See raw diff
 
protac_degradation_predictor/data/PROTAC-Degradation-DB.csv ADDED
The diff for this file is too large to render. See raw diff
 
protac_degradation_predictor/data_utils.py CHANGED
@@ -1,9 +1,9 @@
  import os
  import pkg_resources
  import pickle
- from typing import Dict
+ from typing import Dict, Optional

- from config import config
+ from .config import config

  import h5py
  import numpy as np
@@ -19,8 +19,19 @@ memory = Memory(cachedir, verbose=0)


  @memory.cache
- def load_protein2embedding() -> Dict[str, np.ndarray]:
-     embeddings_path = pkg_resources.resource_stream(__name__, 'data/uniprot2embedding.h5')
+ def load_protein2embedding(
+     embeddings_path: Optional[str] = None,
+ ) -> Dict[str, np.ndarray]:
+     """ Load the protein embeddings from a file.
+
+     Args:
+         embeddings_path (str): The path to the embeddings file.
+
+     Returns:
+         Dict[str, np.ndarray]: A dictionary of protein embeddings.
+     """
+     if embeddings_path is None:
+         embeddings_path = pkg_resources.resource_stream(__name__, 'data/uniprot2embedding.h5')
      protein2embedding = {}
      with h5py.File(embeddings_path, "r") as file:
          for sequence_id in file.keys():
@@ -30,17 +41,74 @@ def load_protein2embedding() -> Dict[str, np.ndarray]:


  @memory.cache
- def load_cell2embedding() -> Dict[str, np.ndarray]:
-     embeddings_path = pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl')
+ def load_cell2embedding(
+     embeddings_path: Optional[str] = None,
+ ) -> Dict[str, np.ndarray]:
+     """ Load the cell line embeddings from a file.
+
+     Args:
+         embeddings_path (str): The path to the embeddings file.
+
+     Returns:
+         Dict[str, np.ndarray]: A dictionary of cell line embeddings.
+     """
+     if embeddings_path is None:
+         embeddings_path = pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl')
      with open(embeddings_path, 'rb') as f:
          cell2embedding = pickle.load(f)
      return cell2embedding


- def get_fingerprint(smiles: str) -> np.ndarray:
-     morgan_fpgen = AllChem.GetMorganGenerator(
-         radius=config.morgan_radius,
-         fpSize=config.fingerprint_size,
-         includeChirality=True,
-     )
-     return morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
+ def get_fingerprint(smiles: str, morgan_fpgen=None) -> np.ndarray:
+     """ Get the Morgan fingerprint of a molecule.
+
+     Args:
+         smiles (str): The SMILES string of the molecule.
+         morgan_fpgen: The Morgan fingerprint generator.
+
+     Returns:
+         np.ndarray: The Morgan fingerprint.
+     """
+     if morgan_fpgen is None:
+         morgan_fpgen = AllChem.GetMorganGenerator(
+             radius=config.morgan_radius,
+             fpSize=config.fingerprint_size,
+             includeChirality=True,
+         )
+     return morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
+
+
+ def is_active(
+     DC50: float,
+     Dmax: float,
+     pDC50_threshold: float = 7.0,
+     Dmax_threshold: float = 0.8,
+     oring: bool = False,  # Deprecated
+ ) -> bool:
+     """ Check if a PROTAC is active based on DC50 and Dmax.
+
+     Args:
+         DC50 (float): DC50 in nM.
+         Dmax (float): Dmax in %.
+
+     Returns:
+         bool: True if active, False if inactive, np.nan if either DC50 or Dmax is NaN.
+     """
+     pDC50 = -np.log10(DC50 * 1e-9) if pd.notnull(DC50) else np.nan
+     Dmax = Dmax / 100
+     if pd.notnull(pDC50):
+         if pDC50 < pDC50_threshold:
+             return False
+     if pd.notnull(Dmax):
+         if Dmax < Dmax_threshold:
+             return False
+     if oring:
+         if pd.notnull(pDC50):
+             return True if pDC50 >= pDC50_threshold else False
+         elif pd.notnull(Dmax):
+             return True if Dmax >= Dmax_threshold else False
+         else:
+             return np.nan
+     else:
+         if pd.notnull(pDC50) and pd.notnull(Dmax):
+             return True if pDC50 >= pDC50_threshold and Dmax >= Dmax_threshold else False
+         else:
+             return np.nan
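A quick sanity check of the new helpers, with illustrative values (the default thresholds above are pDC50 >= 7.0 and Dmax >= 0.8):

    from protac_degradation_predictor.data_utils import is_active, get_fingerprint

    # DC50 = 5 nM  -> pDC50 = -log10(5e-9) ~ 8.3;  Dmax = 85% -> 0.85
    is_active(DC50=5.0, Dmax=85.0)            # True: both thresholds met
    is_active(DC50=500.0, Dmax=85.0)          # False: pDC50 ~ 6.3 < 7.0
    is_active(DC50=float('nan'), Dmax=85.0)   # nan: the default AND semantics need both values
    fp = get_fingerprint('CCO')               # bit-vector Morgan fingerprint using the config defaults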
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -1,8 +1,8 @@
  import os
  from typing import Literal, List, Tuple, Optional, Dict

- from pytorch_models import train_model
- from sklearn_models import (
+ from .pytorch_models import train_model
+ from .sklearn_models import (
      train_sklearn_model,
      suggest_random_forest,
      suggest_logistic_regression,
protac_degradation_predictor/protac_degradation_predictor.py CHANGED
@@ -1,20 +1,22 @@
  import pkg_resources
  import logging

- from pytorch_models import PROTAC_Model, load_model
- from data_utils import (
+ from .pytorch_models import PROTAC_Model, load_model
+ from .data_utils import (
      load_protein2embedding,
      load_cell2embedding,
      get_fingerprint,
  )
- from config import config
+ from .config import config

  import numpy as np
  import torch
  from torch import sigmoid

+
  package_name = 'protac_degradation_predictor'

+
  def get_protac_active_proba(
      protac_smiles: str,
      e3_ligase: str,
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -1,8 +1,8 @@
  import warnings
  from typing import Literal, List, Tuple, Optional, Dict

- from protac_dataset import PROTAC_Dataset
- from config import Config
+ from .protac_dataset import PROTAC_Dataset
+ from .config import Config

  import pandas as pd
  import numpy as np
protac_degradation_predictor/sklearn_models.py CHANGED
@@ -1,6 +1,6 @@
  from typing import Literal, List, Tuple, Optional, Dict

- from protac_dataset import PROTAC_Dataset
+ from .protac_dataset import PROTAC_Dataset

  import pandas as pd
  from sklearn.base import ClassifierMixin
src/main.py ADDED
@@ -0,0 +1,323 @@
+ import os
+ from collections import defaultdict
+ import warnings
+
+ from protac_degradation_predictor.config import config
+ from protac_degradation_predictor.data_utils import (
+     load_protein2embedding,
+     load_cell2embedding,
+     is_active,
+ )
+ from protac_degradation_predictor.pytorch_models import (
+     train_model,
+ )
+ from protac_degradation_predictor.optuna_utils import (
+     hyperparameter_tuning_and_training,
+ )
+
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from rdkit import DataStructs
+ from jsonargparse import CLI
+ import pandas as pd
+ from tqdm import tqdm
+ import numpy as np
+ from sklearn.preprocessing import OrdinalEncoder
+ from sklearn.model_selection import (
+     StratifiedKFold,
+     StratifiedGroupKFold,
+ )
+
+
+ # Ignore UserWarning from Matplotlib
+ warnings.filterwarnings("ignore", ".*FixedLocator*")
+ # Ignore UserWarning from PyTorch Lightning
+ warnings.filterwarnings("ignore", ".*does not have many workers.*")
+
+
+ def main(
+     active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
+     n_trials: int = 50,
+     fast_dev_run: bool = False,
+     test_split: float = 0.2,
+     cv_n_splits: int = 5,
+ ):
+     """ Train a PROTAC model using the given datasets and hyperparameters.
+
+     Args:
+         active_col (str): The activity column to predict.
+         n_trials (int): The number of hyperparameter optimization trials.
+         cv_n_splits (int): The number of cross-validation splits.
+         fast_dev_run (bool): Whether to run a fast development run.
+     """
+     ## Set the Column to Predict
+     active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
+
+     # Get Dmax_threshold from the active_col
+     Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
+     pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
+
+     ## Load the Data
+     protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
+
+     # Map E3 Ligase Iap to IAP
+     protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
+
+     protac_df[active_col] = protac_df.apply(
+         lambda x: is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
+     )
+
+     ## Test Sets
+
+     test_indeces = {}
+
+     ### Random Split
+
+     # Randomly select 20% of the active PROTACs as the test set
+     active_df = protac_df[protac_df[active_col].notna()].copy()
+     test_df = active_df.sample(frac=test_split, random_state=42)
+     test_indeces['random'] = test_df.index
+
+     ### E3-based Split
+
+     encoder = OrdinalEncoder()
+     protac_df['E3 Group'] = encoder.fit_transform(protac_df[['E3 Ligase']]).astype(int)
+     active_df = protac_df[protac_df[active_col].notna()].copy()
+     test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
+     test_indeces['e3_ligase'] = test_df.index
+
+     ### Tanimoto-based Split
+
+     #### Precompute fingerprints
+     morgan_fpgen = AllChem.GetMorganGenerator(
+         radius=config.morgan_radius,
+         fpSize=config.fingerprint_size,
+         includeChirality=True,
+     )
+
+     smiles2fp = {}
+     for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
+         # Get the fingerprint as a bit vector
+         morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
+         smiles2fp[smiles] = morgan_fp
+
+     # Get the pair-wise Tanimoto similarity between the PROTAC fingerprints
+     tanimoto_matrix = defaultdict(list)
+     for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
+         fp1 = smiles2fp[smiles1]
+         # TODO: Use BulkTanimotoSimilarity for better performance
+         for j, smiles2 in enumerate(protac_df['Smiles'].unique()):
+             if j < i:
+                 continue
+             fp2 = smiles2fp[smiles2]
+             tanimoto_dist = DataStructs.TanimotoSimilarity(fp1, fp2)
+             tanimoto_matrix[smiles1].append(tanimoto_dist)
+     avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
+     protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
+
+     smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
+
+     # Make the grouping of the PROTACs based on the Tanimoto similarity
+     n_bins_tanimoto = 200
+     tanimoto_groups = pd.cut(protac_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
+     encoder = OrdinalEncoder()
+     protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
+     active_df = protac_df[protac_df[active_col].notna()].copy()
+     # Sort the groups so that samples with the highest Tanimoto similarity,
+     # i.e., the "less similar" ones, are placed in the test set first
+     tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
+
+     test_df = []
+     # For each group, get the number of active and inactive entries. Then, add those
+     # entries to the test_df if: 1) the test_df length plus the group entries is less
+     # than 20% of the active_df length, and 2) the percentage of True and False entries
+     # in the active_col in test_df is roughly 50%.
+     for group in tanimoto_groups:
+         group_df = active_df[active_df['Tanimoto Group'] == group]
+         if test_df == []:
+             test_df.append(group_df)
+             continue
+
+         num_entries = len(group_df)
+         num_active_group = group_df[active_col].sum()
+         num_inactive_group = num_entries - num_active_group
+
+         tmp_test_df = pd.concat(test_df)
+         num_entries_test = len(tmp_test_df)
+         num_active_test = tmp_test_df[active_col].sum()
+         num_inactive_test = num_entries_test - num_active_test
+
+         # Check if the group entries can be added to the test_df
+         if num_entries_test + num_entries < test_split * len(active_df):
+             # Add anything at the beginning
+             if num_entries_test + num_entries < test_split / 2 * len(active_df):
+                 test_df.append(group_df)
+                 continue
+             # Be more selective and make sure that the percentage of active and
+             # inactive is balanced
+             if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
+                 if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
+                     test_df.append(group_df)
+     test_df = pd.concat(test_df)
+     # Save to global dictionary of test indices
+     test_indeces['tanimoto'] = test_df.index
+
+     ### Target-based Split
+
+     encoder = OrdinalEncoder()
+     protac_df['Uniprot Group'] = encoder.fit_transform(protac_df[['Uniprot']]).astype(int)
+     active_df = protac_df[protac_df[active_col].notna()].copy()
+
+     test_df = []
+     # For each group, get the number of active and inactive entries. Then, add those
+     # entries to the test_df if: 1) the test_df length plus the group entries is less
+     # than 20% of the active_df length, and 2) the percentage of True and False entries
+     # in the active_col in test_df is roughly 50%.
+     # Start the loop from the groups containing the smallest number of entries.
+     for group in reversed(active_df['Uniprot'].value_counts().index):
+         group_df = active_df[active_df['Uniprot'] == group]
+         if test_df == []:
+             test_df.append(group_df)
+             continue
+
+         num_entries = len(group_df)
+         num_active_group = group_df[active_col].sum()
+         num_inactive_group = num_entries - num_active_group
+
+         tmp_test_df = pd.concat(test_df)
+         num_entries_test = len(tmp_test_df)
+         num_active_test = tmp_test_df[active_col].sum()
+         num_inactive_test = num_entries_test - num_active_test
+
+         # Check if the group entries can be added to the test_df
+         if num_entries_test + num_entries < test_split * len(active_df):
+             # Add anything at the beginning
+             if num_entries_test + num_entries < test_split / 2 * len(active_df):
+                 test_df.append(group_df)
+                 continue
+             # Be more selective and make sure that the percentage of active and
+             # inactive is balanced
+             if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
+                 if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
+                     test_df.append(group_df)
+     test_df = pd.concat(test_df)
+     # Save to global dictionary of test indices
+     test_indeces['uniprot'] = test_df.index
+
+     ## Cross-Validation Training
+
+     # Make directory ../reports if it does not exist
+     if not os.path.exists('../reports'):
+         os.makedirs('../reports')
+
+     # Load embedding dictionaries
+     protein2embedding = load_protein2embedding('../data/uniprot2embedding.h5')
+     cell2embedding = load_cell2embedding('../data/cell2embedding.pkl')
+
+     report = []
+     for split_type, indeces in test_indeces.items():
+         active_df = protac_df[protac_df[active_col].notna()].copy()
+         test_df = active_df.loc[indeces]
+         train_val_df = active_df[~active_df.index.isin(test_df.index)]
+
+         if split_type == 'random':
+             kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+             group = None
+         elif split_type == 'e3_ligase':
+             kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+             group = train_val_df['E3 Group'].to_numpy()
+         elif split_type == 'tanimoto':
+             kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+             group = train_val_df['Tanimoto Group'].to_numpy()
+         elif split_type == 'uniprot':
+             kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+             group = train_val_df['Uniprot Group'].to_numpy()
+         # Start the CV over the folds
+         X = train_val_df.drop(columns=active_col)
+         y = train_val_df[active_col].tolist()
+         for k, (train_index, val_index) in enumerate(kf.split(X, y, group)):
+             print('-' * 100)
+             print(f'Starting CV for group type: {split_type}, fold: {k}')
+             print('-' * 100)
+             train_df = train_val_df.iloc[train_index]
+             val_df = train_val_df.iloc[val_index]
+
+             leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
+             leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
+
+             stats = {
+                 'fold': k,
+                 'split_type': split_type,
+                 'train_len': len(train_df),
+                 'val_len': len(val_df),
+                 'train_perc': len(train_df) / len(train_val_df),
+                 'val_perc': len(val_df) / len(train_val_df),
+                 'train_active_perc': train_df[active_col].sum() / len(train_df),
+                 'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
+                 'val_active_perc': val_df[active_col].sum() / len(val_df),
+                 'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
+                 'test_active_perc': test_df[active_col].sum() / len(test_df),
+                 'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
+                 'num_leaking_uniprot': len(leaking_uniprot),
+                 'num_leaking_smiles': len(leaking_smiles),
+                 'train_leaking_uniprot_perc': len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df),
+                 'train_leaking_smiles_perc': len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df),
+             }
+             if split_type != 'random':
+                 stats['train_unique_groups'] = len(np.unique(group[train_index]))
+                 stats['val_unique_groups'] = len(np.unique(group[val_index]))
+
+             print(stats)
+             # # Train and evaluate the model
+             # model, trainer, metrics = hyperparameter_tuning_and_training(
+             #     protein2embedding,
+             #     cell2embedding,
+             #     smiles2fp,
+             #     train_df,
+             #     val_df,
+             #     test_df,
+             #     fast_dev_run=fast_dev_run,
+             #     n_trials=n_trials,
+             #     logger_name=f'protac_{active_name}_{split_type}_fold_{k}_test_split_{test_split}',
+             #     active_label=active_col,
+             #     study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}.pkl',
+             # )
+             # hparams = {p.replace('hparam_', ''): v for p, v in stats.items() if p.startswith('hparam_')}
+             # stats.update(metrics)
+             # report.append(stats.copy())
+             # del model
+             # del trainer
+
+             # # Ablation study: disable embeddings at a time
+             # for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
+             #     print('-' * 100)
+             #     print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
+             #     print('-' * 100)
+             #     stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
+             #     model, trainer, metrics = train_model(
+             #         protein2embedding,
+             #         cell2embedding,
+             #         smiles2fp,
+             #         train_df,
+             #         val_df,
+             #         test_df,
+             #         fast_dev_run=fast_dev_run,
+             #         logger_name=f'protac_{active_name}_{split_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
+             #         active_label=active_col,
+             #         disabled_embeddings=disabled_embeddings,
+             #         **hparams,
+             #     )
+             #     stats.update(metrics)
+             #     report.append(stats.copy())
+             #     del model
+             #     del trainer
+
+     # report_df = pd.DataFrame(report)
+     # report_df.to_csv(
+     #     f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}_sklearn.csv',
+     #     index=False,
+     # )
+
+
+ if __name__ == '__main__':
+     cli = CLI(main)
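Since the script wraps main in jsonargparse's CLI, the function arguments become command-line flags; the data and report paths are relative ('../data', '../reports'), so the sweep is presumably meant to be launched from inside src/, for example (illustrative invocation):

    python main.py --n_trials 50 --test_split 0.2 --cv_n_splits 5 --fast_dev_run false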