Commit
·
82509b6
1
Parent(s):
cf7560f
Updated experiment scripts
Browse files- src/{run_experiments_aminoacid_counts.py → run_experiments_aminoacidcnt.py} +0 -0
- src/{run_experiments_cells_onehot.py → run_experiments_cellsonehot.py} +0 -0
- src/run_experiments_cellsonehot_aminoacidcnt.py +168 -0
- src/{run_experiments.py → run_experiments_pytorch.py} +3 -3
- src/run_experiments_xgboost.py +1 -1
src/{run_experiments_aminoacid_counts.py → run_experiments_aminoacidcnt.py}
RENAMED
File without changes
|
src/{run_experiments_cells_onehot.py → run_experiments_cellsonehot.py}
RENAMED
File without changes
|
src/run_experiments_cellsonehot_aminoacidcnt.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from collections import defaultdict
|
4 |
+
import warnings
|
5 |
+
import logging
|
6 |
+
from typing import Literal
|
7 |
+
|
8 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
9 |
+
|
10 |
+
import protac_degradation_predictor as pdp
|
11 |
+
|
12 |
+
import pytorch_lightning as pl
|
13 |
+
from rdkit import Chem
|
14 |
+
from rdkit.Chem import AllChem
|
15 |
+
from rdkit import DataStructs
|
16 |
+
from jsonargparse import CLI
|
17 |
+
import pandas as pd
|
18 |
+
from tqdm import tqdm
|
19 |
+
import numpy as np
|
20 |
+
from sklearn.preprocessing import OrdinalEncoder
|
21 |
+
from sklearn.model_selection import (
|
22 |
+
StratifiedKFold,
|
23 |
+
StratifiedGroupKFold,
|
24 |
+
)
|
25 |
+
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
|
26 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
27 |
+
|
28 |
+
# Ignore UserWarning from Matplotlib
|
29 |
+
warnings.filterwarnings("ignore", ".*FixedLocator*")
|
30 |
+
# Ignore UserWarning from PyTorch Lightning
|
31 |
+
warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
32 |
+
|
33 |
+
|
34 |
+
root = logging.getLogger()
|
35 |
+
root.setLevel(logging.DEBUG)
|
36 |
+
|
37 |
+
handler = logging.StreamHandler(sys.stdout)
|
38 |
+
handler.setLevel(logging.DEBUG)
|
39 |
+
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
40 |
+
handler.setFormatter(formatter)
|
41 |
+
root.addHandler(handler)
|
42 |
+
|
43 |
+
def main(
|
44 |
+
active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
|
45 |
+
n_trials: int = 100,
|
46 |
+
fast_dev_run: bool = False,
|
47 |
+
test_split: float = 0.1,
|
48 |
+
cv_n_splits: int = 5,
|
49 |
+
max_epochs: int = 100,
|
50 |
+
force_study: bool = False,
|
51 |
+
experiments: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
|
52 |
+
):
|
53 |
+
""" Run experiments with the cells one-hot encoding model.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
active_col (str): Name of the column containing the active values.
|
57 |
+
n_trials (int): Number of hyperparameter optimization trials.
|
58 |
+
fast_dev_run (bool): Whether to run a fast development run.
|
59 |
+
test_split (float): Percentage of data to use for testing.
|
60 |
+
cv_n_splits (int): Number of cross-validation splits.
|
61 |
+
max_epochs (int): Maximum number of epochs to train the model.
|
62 |
+
force_study (bool): Whether to force the creation of a new study.
|
63 |
+
experiments (str): Type of experiments to run. Options are 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
|
64 |
+
"""
|
65 |
+
pl.seed_everything(42)
|
66 |
+
|
67 |
+
# Make directory ../reports if it does not exist
|
68 |
+
if not os.path.exists('../reports'):
|
69 |
+
os.makedirs('../reports')
|
70 |
+
|
71 |
+
# Load embedding dictionaries
|
72 |
+
protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
|
73 |
+
cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
|
74 |
+
|
75 |
+
# Get one-hot encoded embeddings for cell lines
|
76 |
+
onehotenc = OneHotEncoder(sparse_output=False)
|
77 |
+
cell_embeddings = onehotenc.fit_transform(
|
78 |
+
np.array(list(cell2embedding.keys())).reshape(-1, 1)
|
79 |
+
)
|
80 |
+
cell2embedding = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
|
81 |
+
|
82 |
+
# Create a new protein2embedding dictionary with amino acid sequence
|
83 |
+
protac_df = pdp.load_curated_dataset()
|
84 |
+
# Create the dictionary mapping 'Uniprot' to 'POI Sequence'
|
85 |
+
protein2embedding = protac_df.set_index('Uniprot')['POI Sequence'].to_dict()
|
86 |
+
# Create the dictionary mapping 'E3 Ligase Uniprot' to 'E3 Ligase Sequence'
|
87 |
+
e32seq = protac_df.set_index('E3 Ligase Uniprot')['E3 Ligase Sequence'].to_dict()
|
88 |
+
# Merge the two dictionaries into a new protein2embedding dictionary
|
89 |
+
protein2embedding.update(e32seq)
|
90 |
+
|
91 |
+
# Get count vectorized embeddings for proteins
|
92 |
+
# NOTE: Check that the protein2embedding is a dictionary of strings
|
93 |
+
if not all(isinstance(k, str) for k in protein2embedding.keys()):
|
94 |
+
raise ValueError("All keys in `protein2embedding` must be strings.")
|
95 |
+
countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
|
96 |
+
protein_embeddings = countvec.fit_transform(
|
97 |
+
list(protein2embedding.keys())
|
98 |
+
).toarray()
|
99 |
+
protein2embedding = {k: v for k, v in zip(protein2embedding.keys(), protein_embeddings)}
|
100 |
+
|
101 |
+
studies_dir = '../data/studies'
|
102 |
+
train_val_perc = f'{int((1 - test_split) * 100)}'
|
103 |
+
test_perc = f'{int(test_split * 100)}'
|
104 |
+
active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
|
105 |
+
|
106 |
+
if experiments == 'all':
|
107 |
+
experiments = ['standard', 'similarity', 'target']
|
108 |
+
else:
|
109 |
+
experiments = [experiments]
|
110 |
+
|
111 |
+
# Cross-Validation Training
|
112 |
+
reports = defaultdict(list)
|
113 |
+
for split_type in experiments:
|
114 |
+
|
115 |
+
train_val_filename = f'{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
|
116 |
+
test_filename = f'{split_type}_test_{test_perc}split_{active_name}.csv'
|
117 |
+
|
118 |
+
train_val_df = pd.read_csv(os.path.join(studies_dir, train_val_filename))
|
119 |
+
test_df = pd.read_csv(os.path.join(studies_dir, test_filename))
|
120 |
+
|
121 |
+
# Get SMILES and precompute fingerprints dictionary
|
122 |
+
unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
|
123 |
+
smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}
|
124 |
+
|
125 |
+
# Get the CV object
|
126 |
+
if split_type == 'standard':
|
127 |
+
kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
128 |
+
group = None
|
129 |
+
elif split_type == 'e3_ligase':
|
130 |
+
kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
131 |
+
group = train_val_df['E3 Group'].to_numpy()
|
132 |
+
elif split_type == 'similarity':
|
133 |
+
kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
134 |
+
group = train_val_df['Tanimoto Group'].to_numpy()
|
135 |
+
elif split_type == 'target':
|
136 |
+
kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
137 |
+
group = train_val_df['Uniprot Group'].to_numpy()
|
138 |
+
|
139 |
+
# Start the experiment
|
140 |
+
experiment_name = f'{split_type}_{active_name}_test_split_{test_split}'
|
141 |
+
optuna_reports = pdp.hyperparameter_tuning_and_training(
|
142 |
+
protein2embedding=protein2embedding,
|
143 |
+
cell2embedding=cell2embedding,
|
144 |
+
smiles2fp=smiles2fp,
|
145 |
+
train_val_df=train_val_df,
|
146 |
+
test_df=test_df,
|
147 |
+
kf=kf,
|
148 |
+
groups=group,
|
149 |
+
split_type=split_type,
|
150 |
+
n_models_for_test=3,
|
151 |
+
fast_dev_run=fast_dev_run,
|
152 |
+
n_trials=n_trials,
|
153 |
+
max_epochs=max_epochs,
|
154 |
+
logger_save_dir='../logs',
|
155 |
+
logger_name=f'cellsonehot_aminoacidcnt_{experiment_name}',
|
156 |
+
active_label=active_col,
|
157 |
+
study_filename=f'../reports/study_cellsonehot_aminoacidcnt_{experiment_name}.pkl',
|
158 |
+
force_study=force_study,
|
159 |
+
)
|
160 |
+
|
161 |
+
# Save the reports to file
|
162 |
+
for report_name, report in optuna_reports.items():
|
163 |
+
report.to_csv(f'../reports/cellsonehot_aminoacidcnt_{report_name}_{experiment_name}.csv', index=False)
|
164 |
+
reports[report_name].append(report.copy())
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == '__main__':
|
168 |
+
cli = CLI(main)
|
src/{run_experiments.py → run_experiments_pytorch.py}
RENAMED
@@ -346,15 +346,15 @@ def main(
|
|
346 |
n_trials=n_trials,
|
347 |
max_epochs=max_epochs,
|
348 |
logger_save_dir='../logs',
|
349 |
-
logger_name=f'{experiment_name}',
|
350 |
active_label=active_col,
|
351 |
-
study_filename=f'../reports/
|
352 |
force_study=force_study,
|
353 |
)
|
354 |
|
355 |
# Save the reports to file
|
356 |
for report_name, report in optuna_reports.items():
|
357 |
-
report.to_csv(f'../reports/{report_name}_{experiment_name}.csv', index=False)
|
358 |
reports[report_name].append(report.copy())
|
359 |
|
360 |
|
|
|
346 |
n_trials=n_trials,
|
347 |
max_epochs=max_epochs,
|
348 |
logger_save_dir='../logs',
|
349 |
+
logger_name=f'pytorch_{experiment_name}',
|
350 |
active_label=active_col,
|
351 |
+
study_filename=f'../reports/study_pytorch_{experiment_name}.pkl',
|
352 |
force_study=force_study,
|
353 |
)
|
354 |
|
355 |
# Save the reports to file
|
356 |
for report_name, report in optuna_reports.items():
|
357 |
+
report.to_csv(f'../reports/pytorch_{report_name}_{experiment_name}.csv', index=False)
|
358 |
reports[report_name].append(report.copy())
|
359 |
|
360 |
|
src/run_experiments_xgboost.py
CHANGED
@@ -324,7 +324,7 @@ def main(
|
|
324 |
group = train_val_df['Uniprot Group'].to_numpy()
|
325 |
|
326 |
# Start the experiment
|
327 |
-
experiment_name = f'{active_name}_test_split_{test_split}
|
328 |
optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
|
329 |
protein2embedding=protein2embedding,
|
330 |
cell2embedding=cell2embedding,
|
|
|
324 |
group = train_val_df['Uniprot Group'].to_numpy()
|
325 |
|
326 |
# Start the experiment
|
327 |
+
experiment_name = f'{split_type}_{active_name}_test_split_{test_split}'
|
328 |
optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
|
329 |
protein2embedding=protein2embedding,
|
330 |
cell2embedding=cell2embedding,
|