ribesstefano committed on
Commit
82509b6
·
1 Parent(s): cf7560f

Updated experiment scripts

Browse files
src/{run_experiments_aminoacid_counts.py → run_experiments_aminoacidcnt.py} RENAMED
File without changes
src/{run_experiments_cells_onehot.py → run_experiments_cellsonehot.py} RENAMED
File without changes
src/run_experiments_cellsonehot_aminoacidcnt.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from collections import defaultdict
4
+ import warnings
5
+ import logging
6
+ from typing import Literal
7
+
8
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
+
10
+ import protac_degradation_predictor as pdp
11
+
12
+ import pytorch_lightning as pl
13
+ from rdkit import Chem
14
+ from rdkit.Chem import AllChem
15
+ from rdkit import DataStructs
16
+ from jsonargparse import CLI
17
+ import pandas as pd
18
+ from tqdm import tqdm
19
+ import numpy as np
20
+ from sklearn.preprocessing import OrdinalEncoder
21
+ from sklearn.model_selection import (
22
+ StratifiedKFold,
23
+ StratifiedGroupKFold,
24
+ )
25
+ from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
26
+ from sklearn.feature_extraction.text import CountVectorizer
27
+
28
+ # Ignore UserWarning from Matplotlib
29
+ warnings.filterwarnings("ignore", ".*FixedLocator*")
30
+ # Ignore UserWarning from PyTorch Lightning
31
+ warnings.filterwarnings("ignore", ".*does not have many workers.*")
32
+
33
+
34
+ root = logging.getLogger()
35
+ root.setLevel(logging.DEBUG)
36
+
37
+ handler = logging.StreamHandler(sys.stdout)
38
+ handler.setLevel(logging.DEBUG)
39
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
40
+ handler.setFormatter(formatter)
41
+ root.addHandler(handler)
42
+
43
def main(
    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
    n_trials: int = 100,
    fast_dev_run: bool = False,
    test_split: float = 0.1,
    cv_n_splits: int = 5,
    max_epochs: int = 100,
    force_study: bool = False,
    experiments: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
):
    """ Run experiments with one-hot cell encodings and amino acid count protein encodings.

    Args:
        active_col (str): Name of the column containing the active values.
        n_trials (int): Number of hyperparameter optimization trials.
        fast_dev_run (bool): Whether to run a fast development run.
        test_split (float): Percentage of data to use for testing.
        cv_n_splits (int): Number of cross-validation splits.
        max_epochs (int): Maximum number of epochs to train the model.
        force_study (bool): Whether to force the creation of a new study.
        experiments (str): Type of experiments to run. Options are 'all',
            'standard', 'e3_ligase', 'similarity', 'target'. NOTE: 'all'
            expands to standard/similarity/target only; 'e3_ligase' must be
            requested explicitly (preserved from the original script).

    Raises:
        ValueError: If a protein sequence is not a string, or if an unknown
            split type is requested.
    """
    pl.seed_everything(42)

    # Make directory ../reports if it does not exist
    os.makedirs('../reports', exist_ok=True)

    # Load embedding dictionaries
    protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')

    # Replace the cell embeddings with one-hot encodings of the cell line names
    onehotenc = OneHotEncoder(sparse_output=False)
    cell_embeddings = onehotenc.fit_transform(
        np.array(list(cell2embedding.keys())).reshape(-1, 1)
    )
    cell2embedding = dict(zip(cell2embedding.keys(), cell_embeddings))

    # Replace the protein embeddings with the raw amino acid sequences.
    # NOTE: the pre-trained embeddings loaded above are deliberately discarded
    # here; this script ablates them with simple residue-count features.
    protac_df = pdp.load_curated_dataset()
    # Map 'Uniprot' -> 'POI Sequence' and 'E3 Ligase Uniprot' -> 'E3 Ligase Sequence',
    # then merge into a single Uniprot-ID -> sequence dictionary.
    protein2embedding = protac_df.set_index('Uniprot')['POI Sequence'].to_dict()
    e32seq = protac_df.set_index('E3 Ligase Uniprot')['E3 Ligase Sequence'].to_dict()
    protein2embedding.update(e32seq)

    # Turn each amino acid sequence into per-residue character counts.
    # BUGFIX: vectorize the sequences (dict values), not the Uniprot IDs
    # (dict keys) as the original code did — counting the characters of
    # accession strings yields meaningless features.
    if not all(isinstance(seq, str) for seq in protein2embedding.values()):
        raise ValueError("All values in `protein2embedding` must be strings.")
    countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
    protein_embeddings = countvec.fit_transform(
        list(protein2embedding.values())
    ).toarray()
    protein2embedding = dict(zip(protein2embedding.keys(), protein_embeddings))

    studies_dir = '../data/studies'
    train_val_perc = f'{int((1 - test_split) * 100)}'
    test_perc = f'{int(test_split * 100)}'
    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')

    if experiments == 'all':
        experiments = ['standard', 'similarity', 'target']
    else:
        experiments = [experiments]

    # Cross-Validation Training
    reports = defaultdict(list)
    for split_type in experiments:

        train_val_filename = f'{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
        test_filename = f'{split_type}_test_{test_perc}split_{active_name}.csv'

        train_val_df = pd.read_csv(os.path.join(studies_dir, train_val_filename))
        test_df = pd.read_csv(os.path.join(studies_dir, test_filename))

        # Get SMILES and precompute fingerprints dictionary
        unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
        smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}

        # Get the CV object: grouped splitters keep similar compounds/targets
        # in the same fold to avoid leakage.
        if split_type == 'standard':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = None
        elif split_type == 'e3_ligase':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['E3 Group'].to_numpy()
        elif split_type == 'similarity':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Tanimoto Group'].to_numpy()
        elif split_type == 'target':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Uniprot Group'].to_numpy()
        else:
            # BUGFIX: the original fell through here with `kf`/`group`
            # undefined, raising a confusing NameError below.
            raise ValueError(f"Unknown split type: {split_type}")

        # Start the experiment
        experiment_name = f'{split_type}_{active_name}_test_split_{test_split}'
        optuna_reports = pdp.hyperparameter_tuning_and_training(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_val_df=train_val_df,
            test_df=test_df,
            kf=kf,
            groups=group,
            split_type=split_type,
            n_models_for_test=3,
            fast_dev_run=fast_dev_run,
            n_trials=n_trials,
            max_epochs=max_epochs,
            logger_save_dir='../logs',
            logger_name=f'cellsonehot_aminoacidcnt_{experiment_name}',
            active_label=active_col,
            study_filename=f'../reports/study_cellsonehot_aminoacidcnt_{experiment_name}.pkl',
            force_study=force_study,
        )

        # Save the reports to file
        for report_name, report in optuna_reports.items():
            report.to_csv(f'../reports/cellsonehot_aminoacidcnt_{report_name}_{experiment_name}.csv', index=False)
            reports[report_name].append(report.copy())
165
+
166
+
167
# Script entry point: jsonargparse builds a command-line interface from
# main()'s signature and docstring, then invokes it.
if __name__ == '__main__':
    cli = CLI(main)
src/{run_experiments.py → run_experiments_pytorch.py} RENAMED
@@ -346,15 +346,15 @@ def main(
346
  n_trials=n_trials,
347
  max_epochs=max_epochs,
348
  logger_save_dir='../logs',
349
- logger_name=f'{experiment_name}',
350
  active_label=active_col,
351
- study_filename=f'../reports/study_{experiment_name}.pkl',
352
  force_study=force_study,
353
  )
354
 
355
  # Save the reports to file
356
  for report_name, report in optuna_reports.items():
357
- report.to_csv(f'../reports/{report_name}_{experiment_name}.csv', index=False)
358
  reports[report_name].append(report.copy())
359
 
360
 
 
346
  n_trials=n_trials,
347
  max_epochs=max_epochs,
348
  logger_save_dir='../logs',
349
+ logger_name=f'pytorch_{experiment_name}',
350
  active_label=active_col,
351
+ study_filename=f'../reports/study_pytorch_{experiment_name}.pkl',
352
  force_study=force_study,
353
  )
354
 
355
  # Save the reports to file
356
  for report_name, report in optuna_reports.items():
357
+ report.to_csv(f'../reports/pytorch_{report_name}_{experiment_name}.csv', index=False)
358
  reports[report_name].append(report.copy())
359
 
360
 
src/run_experiments_xgboost.py CHANGED
@@ -324,7 +324,7 @@ def main(
324
  group = train_val_df['Uniprot Group'].to_numpy()
325
 
326
  # Start the experiment
327
- experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
328
  optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
329
  protein2embedding=protein2embedding,
330
  cell2embedding=cell2embedding,
 
324
  group = train_val_df['Uniprot Group'].to_numpy()
325
 
326
  # Start the experiment
327
+ experiment_name = f'{split_type}_{active_name}_test_split_{test_split}'
328
  optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
329
  protein2embedding=protein2embedding,
330
  cell2embedding=cell2embedding,