File size: 14,364 Bytes
33f1644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a589b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33f1644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e811f2
33f1644
a589b70
 
 
33f1644
8315295
33f1644
 
 
 
a589b70
 
33f1644
 
 
 
1e811f2
a589b70
 
 
 
33f1644
8315295
33f1644
 
 
 
a589b70
 
33f1644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1171189
33f1644
 
 
 
 
 
 
 
 
 
1171189
33f1644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1171189
33f1644
 
 
 
 
 
 
 
 
 
 
 
1171189
 
 
 
 
33f1644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1171189
33f1644
 
 
 
 
 
 
 
1171189
33f1644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1171189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33f1644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1171189
 
 
a589b70
1171189
 
33f1644
 
 
 
1e811f2
33f1644
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
from typing import Optional, Dict
import logging
import os

from .optuna_utils import get_majority_vote_metrics, get_dataframe_stats
from .protac_dataset import get_datasets

import optuna
import xgboost as xgb
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import xgboost as xgb
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import joblib
from optuna.samplers import TPESampler
import torch


# Silence XGBoost's global logging (verbosity=0 means no messages).
xgb.set_config(verbosity=0)


def get_confidence_scores(y, y_pred, threshold=0.5):
    """Return the mean predicted score over false positives and false negatives.

    These means indicate how "confidently wrong" the model is: a false-positive
    mean near 1.0 (or a false-negative mean near 0.0) means the mistakes are
    made with high confidence.

    Args:
        y (np.ndarray): Ground-truth binary labels (0/1).
        y_pred (np.ndarray): Predicted probabilities in [0, 1].
        threshold (float): Decision threshold used to binarize predictions.

    Returns:
        tuple: (false_positives_mean, false_negatives_mean). Either value is
            NaN when the corresponding error set is empty.
    """
    y_pred_binary = (y_pred > threshold).astype(int)

    # Boolean masks selecting each kind of misclassified example
    false_positives = (y == 0) & (y_pred_binary == 1)
    false_negatives = (y == 1) & (y_pred_binary == 0)

    # Guard empty masks: calling .mean() on an empty slice emits a NumPy
    # RuntimeWarning (and yields NaN); return NaN explicitly instead.
    false_positives_mean = y_pred[false_positives].mean() if false_positives.any() else float('nan')
    false_negatives_mean = y_pred[false_negatives].mean() if false_negatives.any() else float('nan')

    return false_positives_mean, false_negatives_mean


def train_and_evaluate_xgboost(
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        params: dict,
        test_df: Optional[pd.DataFrame] = None,
        active_label: str = 'Active',
        num_boost_round: int = 100,
        shuffle_train_data: bool = False,
) -> tuple:
    """
    Train and evaluate an XGBoost model with the given parameters.

    Args:
        protein2embedding (Dict): Mapping from protein identifier to embedding.
        cell2embedding (Dict): Mapping from cell line identifier to embedding.
        smiles2fp (Dict): Mapping from SMILES string to fingerprint.
        train_df (pd.DataFrame): The training data.
        val_df (pd.DataFrame): The validation data.
        params (dict): Hyperparameters for the XGBoost model.
        test_df (Optional[pd.DataFrame]): The test data, if any.
        active_label (str): The active label column.
        num_boost_round (int): Maximum number of boosting rounds.
        shuffle_train_data (bool): Whether to shuffle the training data before
            fitting (uses the global NumPy RNG state).

    Returns:
        tuple: The trained model, a dict of predictions ('val_pred' and,
            when test_df is given, 'test_pred'), and a dict of metrics.
    """
    # Get datasets and their numpy arrays
    train_ds, val_ds, test_ds = get_datasets(
        protein2embedding=protein2embedding,
        cell2embedding=cell2embedding,
        smiles2fp=smiles2fp,
        train_df=train_df,
        val_df=val_df,
        test_df=test_df,
        disabled_embeddings=[],
        active_label=active_label,
        apply_scaling=False,
    )
    X_train, y_train = train_ds.get_numpy_arrays()
    X_val, y_val = val_ds.get_numpy_arrays()

    # Shuffle the training data
    if shuffle_train_data:
        idx = np.random.permutation(len(X_train))
        X_train, y_train = X_train[idx], y_train[idx]

    # Setup training and validation data in XGBoost data format
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    # NOTE: xgb.train performs early stopping on the *last* entry of `evals`,
    # so the validation set must come last. (Previously the training -- or,
    # when present, the test -- set was last, which early-stopped on the wrong
    # split and leaked test information into model selection.)
    evallist = [(dtrain, 'train')]
    if test_df is not None:
        X_test, y_test = test_ds.get_numpy_arrays()
        dtest = xgb.DMatrix(X_test, label=y_test)
        evallist.append((dtest, 'test'))
    evallist.append((dval, 'eval'))

    model = xgb.train(
        params,
        dtrain,
        num_boost_round=num_boost_round,
        evals=evallist,
        early_stopping_rounds=10,
        verbose_eval=False,
    )

    # Evaluate model on the validation set at the fixed 0.5 threshold
    val_pred = model.predict(dval)
    val_pred_binary = (val_pred > 0.5).astype(int)

    fp_mean, fn_mean = get_confidence_scores(y_val, val_pred)

    metrics = {
        'val_acc': accuracy_score(y_val, val_pred_binary),
        'val_roc_auc': roc_auc_score(y_val, val_pred),
        'val_precision': precision_score(y_val, val_pred_binary),
        'val_recall': recall_score(y_val, val_pred_binary),
        'val_f1_score': f1_score(y_val, val_pred_binary),
        'val_false_positives_mean': fp_mean,
        'val_false_negatives_mean': fn_mean,
    }
    preds = {'val_pred': val_pred}

    # Optionally evaluate on the test set as well
    if test_df is not None:
        test_pred = model.predict(dtest)
        test_pred_binary = (test_pred > 0.5).astype(int)

        fp_mean, fn_mean = get_confidence_scores(y_test, test_pred)

        metrics.update({
            'test_acc': accuracy_score(y_test, test_pred_binary),
            'test_roc_auc': roc_auc_score(y_test, test_pred),
            'test_precision': precision_score(y_test, test_pred_binary),
            'test_recall': recall_score(y_test, test_pred_binary),
            'test_f1_score': f1_score(y_test, test_pred_binary),
            'test_false_positives_mean': fp_mean,
            'test_false_negatives_mean': fn_mean,
        })
        preds.update({'test_pred': test_pred})

    return model, preds, metrics


def xgboost_model_objective(
        trial: optuna.Trial,
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        train_val_df: pd.DataFrame,
        kf: StratifiedKFold,
        groups: Optional[np.array] = None,
        active_label: str = 'Active',
        num_boost_round: int = 100,
        model_name: Optional[str] = None,
) -> float:
    """ Objective function for hyperparameter optimization with XGBoost.

    Runs cross-validated training with trial-suggested hyperparameters and
    stores the per-fold report, the validation predictions, and the full
    parameter dict as user attributes on the trial.

    Args:
        trial (optuna.Trial): The Optuna trial object.
        protein2embedding (Dict): Mapping from protein identifier to embedding.
        cell2embedding (Dict): Mapping from cell line identifier to embedding.
        smiles2fp (Dict): Mapping from SMILES string to fingerprint.
        train_val_df (pd.DataFrame): The training and validation data.
        kf (StratifiedKFold): Stratified K-Folds cross-validator.
        groups (Optional[np.array]): Group labels forwarded to `kf.split`.
        active_label (str): The active label column.
        num_boost_round (int): Maximum number of boosting rounds.
        model_name (Optional[str]): The prefix name of the CV models to save,
            if supplied. Used as: `f"{model_name}_fold{k}.json"`

    Returns:
        float: The negative mean validation ROC AUC across the folds
            (negated because the study minimizes the objective).
    """
    # Suggest hyperparameters to be used across the CV folds
    params = {
        'booster': 'gbtree',
        'tree_method': 'hist', # if torch.cuda.is_available() else 'hist',
        'objective': 'binary:logistic',
        'eval_metric': 'auc',
        'eta': trial.suggest_float('eta', 1e-4, 1e-1, log=True),
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'min_child_weight': trial.suggest_float('min_child_weight', 1e-3, 10.0, log=True),
        'gamma': trial.suggest_float('gamma', 1e-4, 1e-1, log=True),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
    }

    # `drop` already returns a new DataFrame, so no preceding `.copy()` is
    # needed; X is only used to drive the fold splitting below.
    X = train_val_df.drop(columns=active_label)
    y = train_val_df[active_label].tolist()
    report = []
    val_preds = []

    for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
        logging.info(f'Fold {k + 1}/{kf.get_n_splits()}')
        train_df = train_val_df.iloc[train_index]
        val_df = train_val_df.iloc[val_index]

        # Get some statistics from the dataframes
        stats = {
            'model_type': 'XGBoost',
            'fold': k,
            'train_len': len(train_df),
            'val_len': len(val_df),
            'train_perc': len(train_df) / len(train_val_df),
            'val_perc': len(val_df) / len(train_val_df),
        }
        stats.update(get_dataframe_stats(train_df, val_df, active_label=active_label))
        if groups is not None:
            stats['train_unique_groups'] = len(np.unique(groups[train_index]))
            stats['val_unique_groups'] = len(np.unique(groups[val_index]))

        bst, preds, metrics = train_and_evaluate_xgboost(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_df=train_df,
            val_df=val_df,
            params=params,
            active_label=active_label,
            num_boost_round=num_boost_round,
        )
        stats.update(metrics)
        report.append(stats.copy())
        val_preds.append(preds['val_pred'])

        # Optionally persist each fold's model to disk
        if model_name:
            model_filename = f'{model_name}_fold{k}.json'
            bst.save_model(model_filename)
            logging.info(f'CV XGBoost model saved to: {model_filename}')

    # Save the report in the trial
    trial.set_user_attr('report', report)
    trial.set_user_attr('val_preds', val_preds)
    trial.set_user_attr('params', params)

    # Get the average validation metrics across the folds
    mean_val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
    logging.info(f'\tMean val ROC AUC: {mean_val_roc_auc:.4f}')

    # Optuna aims to minimize the objective, so return the negative ROC AUC
    return -mean_val_roc_auc


def xgboost_hyperparameter_tuning_and_training(
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        train_val_df: pd.DataFrame,
        test_df: pd.DataFrame,
        kf: StratifiedKFold,
        groups: Optional[np.array] = None,
        split_type: str = 'random',
        n_models_for_test: int = 3,
        n_trials: int = 50,
        active_label: str = 'Active',
        num_boost_round: int = 100,
        study_filename: Optional[str] = None,
        force_study: bool = False,
        model_name: Optional[str] = None,
) -> dict:
    """ Hyperparameter tuning and training of an XGBoost model.

    Runs an Optuna study over the CV objective, optionally re-trains and saves
    the best CV models, then re-trains `n_models_for_test` models with the
    best hyperparameters on the full train/val data and evaluates them (plus
    their majority vote) on the test set.

    Args:
        protein2embedding (Dict): Mapping from protein identifier to embedding.
        cell2embedding (Dict): Mapping from cell line identifier to embedding.
        smiles2fp (Dict): Mapping from SMILES string to fingerprint.
        train_val_df (pd.DataFrame): The training and validation data.
        test_df (pd.DataFrame): The test data.
        kf (StratifiedKFold): Stratified K-Folds cross-validator.
        groups (Optional[np.array]): Group labels for the samples used while splitting the dataset into train/test set.
        split_type (str): Type of the data split. Used for reporting information.
        n_models_for_test (int): Number of models to train for testing.
        n_trials (int): Number of trials for hyperparameter optimization.
        active_label (str): The active label column.
        num_boost_round (int): Maximum number of epochs.
        study_filename (Optional[str]): File name to save/load the Optuna study.
        force_study (bool): Whether to force the study optimization even if the study file exists.
        model_name (Optional[str]): Prefix used when saving the CV and best models to disk.

    Returns:
        dict: A dictionary containing reports from the CV and test.
    """
    # Set the verbosity of Optuna
    optuna.logging.set_verbosity(optuna.logging.WARNING)
    
    # Create an Optuna study object (fixed sampler seed for reproducibility;
    # the objective returns negative ROC AUC, hence direction='minimize')
    sampler = TPESampler(seed=42)
    study = optuna.create_study(direction='minimize', sampler=sampler)

    # Reuse a previously saved study unless a fresh optimization is forced
    study_loaded = False
    if study_filename and not force_study:
        if os.path.exists(study_filename):
            study = joblib.load(study_filename)
            study_loaded = True
            logging.info(f'Loaded study from {study_filename}')

    if not study_loaded or force_study:
        study.optimize(
            lambda trial: xgboost_model_objective(
                trial=trial,
                protein2embedding=protein2embedding,
                cell2embedding=cell2embedding,
                smiles2fp=smiles2fp,
                train_val_df=train_val_df,
                kf=kf,
                groups=groups,
                active_label=active_label,
                num_boost_round=num_boost_round,
            ),
            n_trials=n_trials,
        )
        if study_filename:
            joblib.dump(study, study_filename)

    # Per-fold metrics of the best trial and its hyperparameters
    cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
    hparam_report = pd.DataFrame([study.best_params])

    # Train the best CV models and store their models by running the objective
    # again with `model_name` set, which makes it save each fold's model
    if model_name:
        xgboost_model_objective(
            trial=study.best_trial,
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_val_df=train_val_df,
            kf=kf,
            groups=groups,
            active_label=active_label,
            num_boost_round=num_boost_round,
            model_name=f'{model_name}_cv_model_{split_type}',
        )

    # Retrain N models with the best hyperparameters (measure model uncertainty)
    best_models = []
    test_report = []
    test_preds = []
    for i in range(n_models_for_test):
        logging.info(f'Training best model {i + 1}/{n_models_for_test}')
        # The test set is passed as `val_df`, so the returned 'val_*' metrics
        # are actually test metrics (renamed below)
        model, preds, metrics = train_and_evaluate_xgboost(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_df=train_val_df,
            val_df=test_df,
            params=study.best_trial.user_attrs['params'],
            active_label=active_label,
            num_boost_round=num_boost_round,
            shuffle_train_data=True,
        )
        metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
        metrics['model_type'] = 'XGBoost'
        metrics['test_model_id'] = i
        metrics.update(get_dataframe_stats(
            train_val_df,
            test_df=test_df,
            active_label=active_label,
        ))
        test_report.append(metrics.copy())
        test_preds.append(torch.tensor(preds['val_pred']))
        best_models.append(model)

        # Save the trained model
        if model_name:
            model_filename = f'{model_name}_best_model_{split_type}_n{i}-test_acc={metrics["test_acc"]:.2f}-test_roc_auc={metrics["test_roc_auc"]:.3f}.json'
            model.save_model(model_filename)
            logging.info(f'Best XGBoost model saved to: {model_filename}')
    test_report = pd.DataFrame(test_report)

    # Get the majority vote for the test predictions
    majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
    majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
    majority_vote_report = pd.DataFrame([majority_vote_metrics])
    majority_vote_report['model_type'] = 'XGBoost'

    # Add a column with the split_type to all reports
    for report in [cv_report, hparam_report, test_report, majority_vote_report]:
        report['split_type'] = split_type

    # Return the reports
    return {
        'cv_report': cv_report,
        'hparam_report': hparam_report,
        'test_report': test_report,
        'majority_vote_report' :majority_vote_report,
    }