ribesstefano committed
Commit: d36ec1d
Parent(s): 87b14e7

Polished scripted version of Hparam-CV training

notebooks/protac_degradation_predictor.py CHANGED
@@ -1,47 +1,75 @@
-
-
+import optuna
+import pandas as pd
+from rdkit import Chem
+from rdkit.Chem import AllChem
+
+import h5py
+import numpy as np
+from tqdm.auto import tqdm
+
+import os
+import urllib.request
+
+from sklearn.preprocessing import StandardScaler
+
+# ## Define Torch Dataset
 
-
+from imblearn.over_sampling import SMOTE, ADASYN
+from sklearn.preprocessing import LabelEncoder
 import pandas as pd
+import numpy as np
+
+from torch.utils.data import Dataset, DataLoader
+
+import warnings
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import pytorch_lightning as pl
+from torchmetrics import (
+    Accuracy,
+    AUROC,
+    Precision,
+    Recall,
+    F1Score,
+)
+from torchmetrics import MetricCollection
+
+import pickle
+
+from sklearn.model_selection import (
+    StratifiedKFold,
+    StratifiedGroupKFold,
+)
+from sklearn.preprocessing import OrdinalEncoder
+
 
 protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
 protac_df.head()
 
-# %%
 # Get the unique Article IDs of the entries with NaN values in the Active column
 nan_active = protac_df[protac_df['Active'].isna()]['Article DOI'].unique()
 nan_active
 
-# %%
 # Map E3 Ligase Iap to IAP
 protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
 
-# %%
-protac_df.columns
-
-# %%
 cells = sorted(protac_df['Cell Type'].dropna().unique().tolist())
 print(f'Number of non-cleaned cell lines: {len(cells)}')
 
-# %%
 cells = sorted(protac_df['Cell Line Identifier'].dropna().unique().tolist())
 print(f'Number of cleaned cell lines: {len(cells)}')
 
-# %%
 unlabeled_df = protac_df[protac_df['Active'].isna()]
 print(f'Number of compounds in test set: {len(unlabeled_df)}')
 
-# %% [markdown]
 # ## Load Protein Embeddings
 
-# %% [markdown]
 # Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
 #
 # Please note that running the following cell the first time might take a while.
 
-# %%
-import os
-import urllib.request
 
 download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
 embeddings_path = "../data/uniprot2embedding.h5"
@@ -50,11 +78,6 @@ if not os.path.exists(embeddings_path):
     print(f'Downloading embeddings from {download_link}')
     urllib.request.urlretrieve(download_link, embeddings_path)
 
-# %%
-import h5py
-import numpy as np
-from tqdm.auto import tqdm
-
 protein_embeddings = {}
 with h5py.File("../data/uniprot2embedding.h5", "r") as file:
     print(f"number of entries: {len(file.items()):,}")
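The body of the loading loop falls between this hunk and the next, so only its tail is visible (a `KeyError` fallback to a zero vector). A minimal sketch of what it presumably does, assuming the loop walks the Uniprot accessions used in `protac_df` and reads each 1024-dimensional per-protein embedding from the HDF5 file:

```python
# Sketch only (assumed shape of the elided loop, not the commit's exact code).
protein_embeddings = {}
with h5py.File(embeddings_path, "r") as file:
    for sequence_id in tqdm(protac_df['Uniprot'].unique(), desc='Loading protein embeddings'):
        try:
            # Each entry is a 1024-dim per-protein embedding keyed by Uniprot accession.
            protein_embeddings[sequence_id] = np.array(file[sequence_id])
        except KeyError:
            print(f'KeyError for {sequence_id}')
            protein_embeddings[sequence_id] = np.zeros((1024,))
```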
@@ -74,37 +97,27 @@ with h5py.File("../data/uniprot2embedding.h5", "r") as file:
         print(f'KeyError for {sequence_id}')
         protein_embeddings[sequence_id] = np.zeros((1024,))
 
-# %% [markdown]
 # ## Load Cell Embeddings
 
-# %%
-import pickle
 
 cell2embedding_filepath = '../data/cell2embedding.pkl'
 with open(cell2embedding_filepath, 'rb') as f:
     cell2embedding = pickle.load(f)
 print(f'Loaded {len(cell2embedding)} cell lines')
 
-# %%
 emb_shape = cell2embedding[list(cell2embedding.keys())[0]].shape
 # Assign all-zero vectors to cell lines that are not in the embedding file
 for cell_line in protac_df['Cell Line Identifier'].unique():
     if cell_line not in cell2embedding:
         cell2embedding[cell_line] = np.zeros(emb_shape)
 
-# %% [markdown]
 # ## Precompute Molecular Fingerprints
-
-# %%
-from rdkit import Chem
-from rdkit.Chem import AllChem
-from rdkit.Chem import Draw
-
+
 morgan_radius = 15
 n_bits = 1024
 
 # fpgen = AllChem.GetAtomPairGenerator()
-rdkit_fpgen = AllChem.GetRDKitFPGenerator(maxPath=5, fpSize=512)
+# rdkit_fpgen = AllChem.GetRDKitFPGenerator(maxPath=5, fpSize=512)
 morgan_fpgen = AllChem.GetMorganGenerator(radius=morgan_radius, fpSize=n_bits)
 
 smiles2fp = {}
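The loop that fills `smiles2fp` is elided before the next hunk (which already checks for overlapping fingerprints). A minimal sketch, assuming one Morgan fingerprint per unique SMILES via the `morgan_fpgen` generator defined above; the `MolFromSmiles` guard is an assumption:

```python
# Sketch only: precompute one Morgan fingerprint per unique SMILES string.
for smiles in tqdm(protac_df['Smiles'].unique(), desc='Computing fingerprints'):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # assumption: unparsable SMILES are simply skipped
        continue
    smiles2fp[smiles] = morgan_fpgen.GetFingerprint(mol)
```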
@@ -129,7 +142,6 @@ for smiles, fp in smiles2fp.items():
 print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')
 print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
 
-# %%
 # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
 from rdkit import DataStructs
 from collections import defaultdict
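The pairwise similarity loop itself is elided; the next hunk's header shows it iterating over `protac_df['Smiles'].unique()`, and its result is consumed as `tanimoto_matrix` right below. A minimal sketch of that computation (the symmetric fill of both entries is an assumption):

```python
# Sketch only: average Tanimoto similarity of each PROTAC against all the others.
tanimoto_matrix = defaultdict(list)
unique_smiles = protac_df['Smiles'].unique()
for i, smiles1 in enumerate(tqdm(unique_smiles, desc='Computing Tanimoto similarity')):
    fp1 = smiles2fp[smiles1]
    for smiles2 in unique_smiles[i + 1:]:
        sim = DataStructs.TanimotoSimilarity(fp1, smiles2fp[smiles2])
        tanimoto_matrix[smiles1].append(sim)
        tanimoto_matrix[smiles2].append(sim)
```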
@@ -147,44 +159,14 @@ for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing T
 avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
 protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
 
-# %%
-# # Plot the distribution of the average Tanimoto similarity
-# import seaborn as sns
-# import matplotlib.pyplot as plt
-
-# sns.histplot(protac_df['Avg Tanimoto'], bins=50)
-# plt.xlabel('Average Tanimoto similarity')
-# plt.ylabel('Count')
-# plt.title('Distribution of average Tanimoto similarity')
-# plt.grid(axis='y', alpha=0.5)
-# plt.show()
-
-# %%
 smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
 
-# %% [markdown]
 # ## Set the Column to Predict
 
-# %%
 # active_col = 'Active'
 active_col = 'Active - OR'
 
 
-from sklearn.preprocessing import StandardScaler
-
-# %% [markdown]
-# ## Define Torch Dataset
-
-# %%
-from imblearn.over_sampling import SMOTE, ADASYN
-from sklearn.preprocessing import LabelEncoder
-import pandas as pd
-import numpy as np
-
-# %%
-from torch.utils.data import Dataset, DataLoader
-
-
 class PROTAC_Dataset(Dataset):
     def __init__(
         self,
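The body of `PROTAC_Dataset` sits outside the diff; only its opening lines and, in the next hunk, the end of `__getitem__` (the returned `elem` dict) are visible. A hypothetical, minimal sketch of such a dataset, assuming each element bundles the precomputed fingerprint, protein embedding, cell-line embedding, and label (class name and dict keys here are illustrative, not the file's actual ones):

```python
# Hypothetical sketch, not the PROTAC_Dataset defined in the file.
class SimplePROTACDataset(Dataset):
    def __init__(self, df, smiles2fp, protein_embeddings, cell2embedding, active_col='Active - OR'):
        self.df = df.reset_index(drop=True)
        self.smiles2fp = smiles2fp
        self.protein_embeddings = protein_embeddings
        self.cell2embedding = cell2embedding
        self.active_col = active_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        elem = {
            'smiles_emb': self.smiles2fp[row['Smiles']].astype(np.float32),
            'poi_emb': self.protein_embeddings[row['Uniprot']].astype(np.float32),
            'cell_emb': self.cell2embedding[row['Cell Line Identifier']].astype(np.float32),
            'active': float(row[self.active_col]),
        }
        return elem
```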
@@ -295,22 +277,6 @@ class PROTAC_Dataset(Dataset):
         }
         return elem
 
-# %%
-import warnings
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import pytorch_lightning as pl
-from torchmetrics import (
-    Accuracy,
-    AUROC,
-    Precision,
-    Recall,
-    F1Score,
-)
-from torchmetrics import MetricCollection
-
 # Ignore UserWarning from PyTorch Lightning
 warnings.filterwarnings("ignore", ".*does not have many workers.*")
 
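The `PROTAC_Model` LightningModule that consumes the torchmetrics imports is outside the diff. A minimal sketch of how these metrics are commonly grouped for a binary classifier (the `task` argument and the prefixes are assumptions, not taken from the file):

```python
# Sketch only: a typical torchmetrics bundle for binary activity prediction.
metrics = MetricCollection({
    'acc': Accuracy(task='binary'),
    'roc_auc': AUROC(task='binary'),
    'precision': Precision(task='binary'),
    'recall': Recall(task='binary'),
    'f1': F1Score(task='binary'),
})
train_metrics = metrics.clone(prefix='train_')
val_metrics = metrics.clone(prefix='val_')
```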
@@ -505,22 +471,17 @@ class PROTAC_Model(pl.LightningModule):
             shuffle=False,
         )
 
-# %% [markdown]
 # ## Test Sets
 
-# %% [markdown]
 # We want a different test set per Cross-Validation (CV) experiment (see further down). We are interested in three scenarios:
 # * Randomly splitting the data into training and test sets. Hence, the test set shall contain unique SMILES and Uniprots
 # * Splitting the data according to their Uniprot. Hence, the test set shall contain unique Uniprots
 # * Splitting the data according to their SMILES, _i.e._, the test set shall contain unique SMILES
 
-# %%
 test_indeces = {}
 
-# %% [markdown]
 # Isolating the unique SMILES and Uniprots:
 
-# %%
 active_df = protac_df[protac_df[active_col].notna()].copy()
 
 # Get the unique SMILES and Uniprot
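The selection of `unique_indices` for the random split sits between this hunk and the next; only the comment above and the assignment `test_indeces['random'] = unique_indices` are visible. A minimal sketch, assuming rows whose SMILES and Uniprot each occur exactly once are the ones reserved for testing:

```python
# Sketch only: pick rows whose SMILES and Uniprot both appear exactly once.
unique_smiles = active_df['Smiles'].value_counts() == 1
unique_uniprot = active_df['Uniprot'].value_counts() == 1
unique_indices = active_df[
    active_df['Smiles'].isin(unique_smiles[unique_smiles].index)
    & active_df['Uniprot'].isin(unique_uniprot[unique_uniprot].index)
].index
test_indeces['random'] = unique_indices
```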
@@ -552,10 +513,8 @@ test_indeces['random'] = unique_indices
 # plt.title('Test set Active - OR distribution')
 # plt.show()
 
-# %% [markdown]
 # Isolating the unique Uniprots:
 
-# %%
 active_df = protac_df[protac_df[active_col].notna()].copy()
 
 unique_uniprot = active_df['Uniprot'].value_counts() == 1
@@ -569,13 +528,11 @@ unique_indices = active_df[active_df['Uniprot'].isin(unique_uniprot.index)].inde
 test_indeces['uniprot'] = unique_indices
 print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')
 
-# %% [markdown]
 # DEPRECATED: The following results in too small a test set. Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
 # * its SMILES is unique
 # * its POI is unique
 # * its (SMILES, POI) pair is unique
 
-# %%
 active_df = protac_df[protac_df[active_col].notna()]
 
 # Find the samples that:
@@ -598,25 +555,16 @@ test_df = active_df.loc[unique_samples]
 
 warnings.filterwarnings("ignore", ".*FixedLocator*")
 
-# %% [markdown]
 # ## Cross-Validation Training
 
-# %% [markdown]
 # Cross validation training with 5 splits. The split operation is done in three different ways:
 #
 # * Random split
 # * POI-wise: some POIs never in both splits
 # * Least Tanimoto similarity PROTAC-wise
 
-# %% [markdown]
 # ### Plotting CV Folds
 
-# %%
-from sklearn.model_selection import (
-    StratifiedKFold,
-    StratifiedGroupKFold,
-)
-from sklearn.preprocessing import OrdinalEncoder
 
 # NOTE: When set to 60, it will result in 29 groups, with nice distributions of
 # the number of unique groups in the train and validation sets, together with
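The fold construction and plotting code is elided here. A minimal sketch of how the two imported splitters map onto the three split strategies listed above (the dataframe name `train_val_df` and the group encoding are illustrative assumptions):

```python
# Sketch only: stratified folds, optionally grouped so a POI (or a cluster of
# Tanimoto-similar PROTACs) never appears in both train and validation.
y = train_val_df[active_col].astype(int)
poi_groups = OrdinalEncoder().fit_transform(train_val_df[['Uniprot']])[:, 0]

kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)        # random split
gkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)  # POI-/Tanimoto-wise

for fold, (train_idx, val_idx) in enumerate(gkf.split(train_val_df, y, groups=poi_groups)):
    train_df, val_df = train_val_df.iloc[train_idx], train_val_df.iloc[val_idx]
```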
@@ -688,10 +636,8 @@ for group_type in groups:
     stats.append(stat)
     print('-' * 120)
 
-# %% [markdown]
 # ### Run CV
 
-# %%
 import warnings
 
 # Seed everything in pytorch lightning
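The seeding call that the comment refers to is elided; in PyTorch Lightning it is typically a single call (the seed value below is an assumption):

```python
pl.seed_everything(42, workers=True)  # assumed seed; the actual value is in the elided code
```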
@@ -805,14 +751,8 @@ def train_model(
     metrics.update(test_metrics)
     return model, trainer, metrics
 
-# %% [markdown]
 # Setup hyperparameter optimization:
 
-# %%
-import optuna
-import pandas as pd
-
-
 def objective(
     trial,
     train_df,
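The body of `objective` and the study setup are elided. A minimal sketch of the usual Optuna pattern, assuming the objective samples a few hyperparameters, delegates to `train_model`, and returns a validation metric (the sampled parameter names, `train_model`'s signature, and the returned metric key are all illustrative assumptions):

```python
# Sketch only: a typical Optuna objective wrapping train_model.
def sketch_objective(trial, train_df, val_df):
    hparams = {
        'hidden_dim': trial.suggest_categorical('hidden_dim', [256, 512, 768]),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True),
    }
    _, _, metrics = train_model(train_df, val_df, **hparams)  # assumed signature
    return metrics['val_acc']  # assumed metric key

study = optuna.create_study(direction='maximize')
study.optimize(lambda t: sketch_objective(t, train_df, val_df), n_trials=20)
```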
@@ -926,10 +866,8 @@ def hyperparameter_tuning_and_training(
 # train_df, val_df, test_df = load_your_data() # You need to load your datasets here
 # model, trainer, best_metrics = hyperparameter_tuning_and_training(train_df, val_df, test_df)
 
-# %% [markdown]
 # Loop over the different splits and train the model:
 
-# %%
 n_splits = 5
 report = []
 active_df = protac_df[protac_df[active_col].notna()]
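The cross-validation driver loop that fills `report` is elided before the final hunk. A minimal sketch, assuming it iterates over the three test splits stored in `test_indeces`, runs `hyperparameter_tuning_and_training` per fold (matching the commented usage above), and collects the returned metrics; the fold construction details are assumptions:

```python
# Sketch only: outer loop over split strategies and CV folds.
for split_type, indices in test_indeces.items():
    test_df = active_df.loc[indices]
    train_val_df = active_df.drop(index=indices)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    y = train_val_df[active_col].astype(int)
    for fold, (train_idx, val_idx) in enumerate(skf.split(train_val_df, y)):
        model, trainer, metrics = hyperparameter_tuning_and_training(
            train_val_df.iloc[train_idx], train_val_df.iloc[val_idx], test_df,
        )
        report.append({'split_type': split_type, 'fold': fold, **metrics})
```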
@@ -998,4 +936,3 @@ report = pd.DataFrame(report)
 report.to_csv(
     f'../reports/cv_report_hparam_search_{n_splits}-splits.csv', index=False,
 )
-