import os
import sys

# Make the repository root importable when running this script directly
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
import seaborn as sns
import pandas as pd


# Shared color palette, one color per hue level in the bar plots
palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']


def plot_training_curves(df, split_type, stage='test', multimodels=False, groupby='model_id'):
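    """Plot per-epoch loss, accuracy, and ROC-AUC training curves.

    When `multimodels` is True, rows are grouped by the `groupby` column
    (e.g. model or fold id) and seaborn aggregates across models into mean
    curves with error bands. The figure is saved to
    plots/training_metrics_{split_type}.pdf.
    """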
    Stage = 'Test' if stage == 'test' else 'Validation'

    # Clean the data
    df = df.dropna(how='all', axis=1)

    # Convert all columns to numeric, setting errors='coerce' to handle non-numeric data
    df = df.apply(pd.to_numeric, errors='coerce')

    # Aggregate metrics per epoch (and per model/fold when multimodels is True)
    if multimodels:
        epoch_data = df.groupby([groupby, 'epoch']).mean().reset_index()
    else:
        epoch_data = df.groupby('epoch').mean().reset_index()

    fig, ax = plt.subplots(3, 1, figsize=(10, 15))

    # Plot loss
    sns.lineplot(data=epoch_data, x='epoch', y='train_loss_epoch', ax=ax[0], label='Training Loss')
    sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_loss', ax=ax[0], label=f'{Stage} Loss', linestyle='--')

    ax[0].set_ylabel('Loss')
    ax[0].legend(loc='lower right')
    ax[0].grid(axis='both', alpha=0.5)

    # Plot accuracy
    sns.lineplot(data=epoch_data, x='epoch', y='train_acc_epoch', ax=ax[1], label='Training Accuracy')
    sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_acc', ax=ax[1], label=f'{Stage} Accuracy', linestyle='--')
    ax[1].set_ylabel('Accuracy')
    ax[1].legend(loc='lower right')
    ax[1].grid(axis='both', alpha=0.5)
    # Accuracy is a fraction in [0, 1]; render the y-axis as a percentage
    ax[1].set_ylim(0, 1.0)
    ax[1].yaxis.set_major_formatter(PercentFormatter(1, decimals=0))

    # Plot ROC-AUC
    sns.lineplot(data=epoch_data, x='epoch', y='train_roc_auc_epoch', ax=ax[2], label='Training ROC-AUC')
    sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_roc_auc', ax=ax[2], label=f'{Stage} ROC-AUC', linestyle='--')
    ax[2].set_ylabel('ROC-AUC')
    ax[2].legend(loc='lower right')
    ax[2].grid(axis='both', alpha=0.5)
    ax[2].set_ylim(0, 1.0)
    ax[2].set_xlabel('Epoch')

    plt.tight_layout()
    plt.savefig(f'plots/training_metrics_{split_type}.pdf', bbox_inches='tight')
    plt.close(fig)  # Free the figure; this function is called in a loop

def plot_performance_metrics(df_cv, df_test, title=None):
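    """Bar plot comparing validation (CV) and test accuracy / ROC AUC.

    Combines the CV report with the test report, adds a majority-class dummy
    baseline per split type, and saves the figure to plots/{title}.pdf.
    """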

    # Extract and prepare CV data
    cols = ['model_type', 'fold', 'val_acc', 'val_roc_auc', 'split_type']
    if 'test_acc' in df_cv.columns:
        cols.extend(['test_acc', 'test_roc_auc'])
    cv_data = df_cv[cols]
    cv_data = cv_data.melt(id_vars=['model_type', 'fold', 'split_type'], var_name='Metric', value_name='Score')
    cv_data['Metric'] = cv_data['Metric'].replace({
        'val_acc': 'Validation Accuracy',
        'val_roc_auc': 'Validation ROC AUC',
        'test_acc': 'Test Accuracy',
        'test_roc_auc': 'Test ROC AUC'
    })
    cv_data['Stage'] = cv_data['Metric'].apply(lambda x: 'Validation' if 'Val' in x else 'Test')
    # Remove test data from CV data
    cv_data = cv_data[cv_data['Stage'] == 'Validation']

    # Extract and prepare test data
    test_data = df_test[['model_type', 'test_acc', 'test_roc_auc', 'split_type']]
    test_data = test_data.melt(id_vars=['model_type', 'split_type'], var_name='Metric', value_name='Score')
    test_data['Metric'] = test_data['Metric'].replace({
        'test_acc': 'Test Accuracy',
        'test_roc_auc': 'Test ROC AUC'
    })
    test_data['Stage'] = 'Test'

    # Combine CV and test data
    combined_data = pd.concat([cv_data, test_data], ignore_index=True)

    # Rename 'split_type' values according to a predefined map for clarity
    group2name = {
        'random': 'Standard Split',
        'uniprot': 'Target Split',
        'tanimoto': 'Similarity Split',
    }
    combined_data['Split Type'] = combined_data['split_type'].map(group2name)

    # Add dummy model data
    dummy_val_acc = []
    dummy_test_acc = []
    for i, group in enumerate(group2name.keys()):
        # Get the majority class in group_df
        group_df = df_cv[df_cv['split_type'] == group]
        major_col = 'inactive' if group_df['val_inactive_perc'].mean() > 0.5 else 'active'
        dummy_val_acc.append(group_df[f'val_{major_col}_perc'].mean())

        group_df = df_test[df_test['split_type'] == group]
        major_col = 'inactive' if group_df['test_inactive_perc'].mean() > 0.5 else 'active'
        dummy_test_acc.append(group_df[f'test_{major_col}_perc'].mean())

    dummy_scores = []
    metrics = ['Validation Accuracy', 'Validation ROC AUC', 'Test Accuracy', 'Test ROC AUC']
    for i in range(len(dummy_val_acc)):
        for metric, score in zip(metrics, [dummy_val_acc[i], 0.5, dummy_test_acc[i], 0.5]):
            dummy_scores.append({
                'Experiment': i,
                'Metric': metric,
                'Score': score,
                'Split Type': 'Dummy model',
            })
    dummy_model = pd.DataFrame(dummy_scores)
    combined_data = pd.concat([combined_data, dummy_model], ignore_index=True)

    # Plotting
    plt.figure(figsize=(12, 6))
    sns.barplot(
        data=combined_data,
        x='Metric',
        y='Score',
        hue='Split Type',
        errorbar=('sd', 1),
        palette=palette)
    plt.title('')
    plt.ylabel('')
    plt.xlabel('')
    plt.ylim(0, 1.0)  # Assuming scores are normalized between 0 and 1
    plt.grid(axis='y', alpha=0.5, linewidth=0.5)

    # Render the y-axis as a percentage
    plt.gca().yaxis.set_major_formatter(PercentFormatter(1, decimals=0))
    # Place the legend below the x-axis, outside the plot, in four columns
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=4)

    # Annotate each bar with its rotated value inside the bar: even-indexed
    # patches hold accuracies (formatted as percentages), odd-indexed ones
    # hold ROC AUC scores (formatted as plain decimals).
    for i, p in enumerate(plt.gca().patches):
        # TODO: For some reason, a few extra rectangles appear in
        # plt.gca().patches beyond the data bars (possibly because dummy_model
        # does not have the same shape as the dataframe holding all the
        # evaluation data); skip anything with near-zero height.
        if p.get_height() < 0.01:
            continue
        if i % 2 == 0:
            value = f'{p.get_height():.1%}'
        else:
            value = f'{p.get_height():.3f}'

        print(f'Plotting value: {p.get_height()} -> {value}')
        x = p.get_x() + p.get_width() / 2
        y = 0.4  # Fixed y position, roughly mid-bar, keeps labels readable
        plt.annotate(value, (x, y), ha='center', va='center', color='black', fontsize=10, rotation=90, alpha=0.8)
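
    # NOTE: sketch of an alternative (assuming matplotlib >= 3.4): labeling via
    # Axes.bar_label iterates only the real bar containers, sidestepping the
    # stray rectangles in plt.gca().patches:
    #   ax = plt.gca()
    #   for container in ax.containers:
    #       ax.bar_label(container,
    #                    labels=[f'{v:.1%}' for v in container.datavalues],
    #                    label_type='center', rotation=90, fontsize=10)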

    plt.savefig(f'plots/{title}.pdf', bbox_inches='tight')
    plt.close()  # Free the figure; this function is called in a loop


def plot_ablation_study(report, title=''):
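    """Horizontal bar plot of test accuracy for each ablation configuration.

    For every split type, the full model is compared against runs with some
    input embeddings disabled and against a majority-class dummy baseline.
    One figure per split type is saved to plots/{title}{split_type}.pdf.
    """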
    # Define the ablation study combinations
    ablation_study_combinations = [
        'disabled smiles',
        'disabled poi',
        'disabled e3',
        'disabled cell',
        'disabled poi e3',
        'disabled poi e3 smiles',
        'disabled poi e3 cell',
    ]

    for group in report['split_type'].unique():
        baseline = report[report['disabled_embeddings'].isna()].copy()
        baseline = baseline[baseline['split_type'] == group]
        baseline['disabled_embeddings'] = 'all embeddings enabled'
        metrics_to_show = ['test_acc']
        baseline = baseline.melt(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')

        print('baseline:\n', baseline)

        ablation_dfs = []
        for disabled_embeddings in ablation_study_combinations:
            tmp = report[report['disabled_embeddings'] == disabled_embeddings].copy()
            tmp = tmp[tmp['split_type'] == group]
            tmp = tmp.melt(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
            ablation_dfs.append(tmp)
        ablation_df = pd.concat(ablation_dfs)

        print('ablation_df:\n', ablation_df)

        # Dummy baseline: accuracy of always predicting the majority class
        dummy_test_df = pd.DataFrame()
        tmp = report[report['split_type'] == group]
        dummy_test_df['score'] = tmp[['test_active_perc', 'test_inactive_perc']].max(axis=1)
        dummy_test_df['metric'] = 'test_acc'
        dummy_test_df['disabled_embeddings'] = 'dummy'

        final_df = pd.concat([dummy_test_df, baseline, ablation_df])

        final_df['metric'] = final_df['metric'].map({
            'val_acc': 'Validation Accuracy',
            'test_acc': 'Test Accuracy',
            'val_roc_auc': 'Val ROC-AUC',
            'test_roc_auc': 'Test ROC-AUC',
        })

        final_df['disabled_embeddings'] = final_df['disabled_embeddings'].map({
            'all embeddings enabled': 'All embeddings enabled',
            'dummy': 'Dummy model',
            'disabled smiles': 'Disabled compound information',
            'disabled e3': 'Disabled E3 information',
            'disabled poi': 'Disabled target information',
            'disabled cell': 'Disabled cell information',
            'disabled poi e3': 'Disabled E3 and target info',
            'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
            'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
        })

        # Print the aggregated scores as a markdown table
        tmp = final_df.groupby(['disabled_embeddings', 'metric']).mean().round(3)
        tmp = tmp.reset_index()

        print('DF to plot:\n', tmp.to_markdown(index=False))

        fig, ax = plt.subplots()

        sns.barplot(data=final_df,
            y='disabled_embeddings',
            x='score',
            hue='metric',
            ax=ax,
            errorbar=('sd', 1),
            palette=sns.color_palette(palette, len(palette)),
            saturation=1,
        )

        ax.grid(axis='x', alpha=0.5)
        ax.tick_params(axis='y', rotation=0)
        ax.set_xlim(0, 1.0)
        ax.xaxis.set_major_formatter(PercentFormatter(1, decimals=0))
        ax.set_ylabel('')
        ax.set_xlabel('')
        ax.legend(loc='upper right')

        # Annotate each bar with its value (as a percentage) inside the bar
        for i, p in enumerate(plt.gca().patches):
            # TODO: For some reason, an additional patch appears at the end of
            # plt.gca().patches with no matching row in the dataframe; skip it.
            if i == len(plt.gca().patches) - 1:
                continue
            value = '{:.1f}%'.format(100 * p.get_width())
            y = p.get_y() + p.get_height() / 2
            x = 0.4  # Fixed x position, roughly mid-bar, keeps labels readable
            plt.annotate(value, (x, y), ha='center', va='center', color='black', fontsize=10, alpha=0.8)
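
        # NOTE: sketch of an alternative (assuming matplotlib >= 3.4), which
        # avoids the stray trailing patch by iterating bar containers directly:
        #   for container in ax.containers:
        #       ax.bar_label(container,
        #                    labels=[f'{v:.1%}' for v in container.datavalues],
        #                    label_type='center', fontsize=10)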

        plt.savefig(f'plots/{title}{group}.pdf', bbox_inches='tight')
        plt.close(fig)  # Free the figure before the next split type


def plot_majority_voting_performance(df):
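    """Reshape the majority-voting test metrics into long format and print them.

    No plot is produced yet; the melted dataframe is printed for inspection.
    """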
    # Expected columns: cv_models, test_acc, test_roc_auc, split_type
    # Melt the test metrics into long format, keyed by ensemble type and split
    df = df.melt(id_vars=['cv_models', 'split_type'], value_vars=['test_acc', 'test_roc_auc'], var_name='Metric', value_name='Score')
    print(df)


def main():
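    """Load all experiment reports and generate the full set of plots."""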
    active_col = 'Active (Dmax 0.6, pDC50 6.0)'
    test_split = 0.1
    n_models_for_test = 3
    cv_n_folds = 5

    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
    dataset_info = f'{active_name}_test_split_{test_split}'
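    # e.g., dataset_info == 'Active_Dmax_0.6_pDC50_6.0_test_split_0.1'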

    # Load the data
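    # Report CSVs are expected under reports/, named
    # '{experiment}{report_type}_report_{dataset_info}_{split}.csv'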
    reports = {}
    for experiment in ['', 'xgboost_', 'cellsonehot_', 'aminoacidcnt_']:
        reports[f'{experiment}cv_train'] = pd.concat([
            pd.read_csv(f'reports/{experiment}cv_report_{dataset_info}_standard.csv'),
            pd.read_csv(f'reports/{experiment}cv_report_{dataset_info}_target.csv'),
            pd.read_csv(f'reports/{experiment}cv_report_{dataset_info}_similarity.csv'),
        ])
        reports[f'{experiment}test'] = pd.concat([
            pd.read_csv(f'reports/{experiment}test_report_{dataset_info}_standard.csv'),
            pd.read_csv(f'reports/{experiment}test_report_{dataset_info}_target.csv'),
            pd.read_csv(f'reports/{experiment}test_report_{dataset_info}_similarity.csv'),
        ])
        reports[f'{experiment}hparam'] = pd.concat([
            pd.read_csv(f'reports/{experiment}hparam_report_{dataset_info}_standard.csv'),
            pd.read_csv(f'reports/{experiment}hparam_report_{dataset_info}_target.csv'),
            pd.read_csv(f'reports/{experiment}hparam_report_{dataset_info}_similarity.csv'),
        ])
        reports[f'{experiment}majority_vote'] = pd.concat([
            pd.read_csv(f'reports/{experiment}majority_vote_report_{dataset_info}_standard.csv'),
            pd.read_csv(f'reports/{experiment}majority_vote_report_{dataset_info}_target.csv'),
            pd.read_csv(f'reports/{experiment}majority_vote_report_{dataset_info}_similarity.csv'),
        ])
        if experiment != 'xgboost_':
            reports[f'{experiment}ablation'] = pd.concat([
                pd.read_csv(f'reports/{experiment}ablation_report_{dataset_info}_standard.csv'),
                pd.read_csv(f'reports/{experiment}ablation_report_{dataset_info}_target.csv'),
                pd.read_csv(f'reports/{experiment}ablation_report_{dataset_info}_similarity.csv'),
            ])

    for experiment in ['', 'xgboost_', 'cellsonehot_', 'aminoacidcnt_']:
        print('=' * 80)
        print(f'Experiment: {experiment}')
        print('=' * 80)

        # Plot training curves
        for split_type in ['standard', 'similarity', 'target']:
            # Skip XGBoost: we don't have its training curves
            if experiment != 'xgboost_':
                # Plot training curves for the best models
                split_metrics = []
                for i in range(n_models_for_test):
                    metrics_dir = f'best_model_n{i}_{experiment}{split_type}_{dataset_info}'
                    metrics = pd.read_csv(f'logs/{metrics_dir}/{metrics_dir}/metrics.csv')
                    metrics['model_id'] = i
                    # Rename 'val_' columns to 'test_' columns
                    metrics = metrics.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
                    split_metrics.append(metrics)
                plot_training_curves(pd.concat(split_metrics), f'{experiment}{split_type}_best_model', multimodels=True)

                # Plot training curves for the CV models
                split_metrics_cv = []
                for i in range(cv_n_folds):
                    metrics_dir = f'cv_model_{experiment}{split_type}_{dataset_info}_fold{i}'
                    metrics = pd.read_csv(f'logs/{metrics_dir}/{metrics_dir}/metrics.csv')
                    metrics['fold'] = i
                    split_metrics_cv.append(metrics)
                plot_training_curves(pd.concat(split_metrics_cv), f'{experiment}{split_type}_cv_model', stage='val', multimodels=True, groupby='fold')

        if experiment != 'xgboost_':
            # Skip XGBoost: we don't have test data for its CV models
            plot_performance_metrics(
                reports[f'{experiment}cv_train'],
                reports[f'{experiment}cv_train'],
                title=f'{experiment}mean_performance-cv_models_as_test',
            )
            plot_performance_metrics(
                reports[f'{experiment}cv_train'],
                reports[f'{experiment}majority_vote'][reports[f'{experiment}majority_vote']['cv_models'] == True],
                title=f'{experiment}majority_vote_performance-cv_models_as_test',
            )
            # Skip XGBoost: we don't have its ablation study
            reports[f'{experiment}test']['disabled_embeddings'] = pd.NA
            plot_ablation_study(
                    pd.concat([
                    reports[f'{experiment}ablation'],
                    reports[f'{experiment}test'],
                ]),
                title=f'{experiment}ablation_study_',
            )

        plot_performance_metrics(
            reports[f'{experiment}cv_train'],
            reports[f'{experiment}test'],
            title=f'{experiment}mean_performance-best_models_as_test',
        )

        # Majority-vote performance with the best models as the test set
        if experiment == 'xgboost_':
            df_test = reports[f'{experiment}majority_vote']
        else:
            df_test = reports[f'{experiment}majority_vote'][reports[f'{experiment}majority_vote']['cv_models'].isna()]
        plot_performance_metrics(
            reports[f'{experiment}cv_train'],
            df_test,
            title=f'{experiment}majority_vote_performance-best_models_as_test',
        )

        # # Print hyperparameter optimization results as a markdown table
        # print(reports[f'{experiment}hparam'][['split_type', 'hidden_dim', 'learning_rate', 'dropout', 'use_smote', 'smote_k_neighbors']].to_markdown(index=False))


if __name__ == '__main__':
    main()