import os |
import sys |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
import matplotlib.pyplot as plt |
import seaborn as sns |
import pandas as pd |
import numpy as np |
# Shared hex colour palette passed to the seaborn bar plots in this module.
palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
def plot_training_curves(df, split_type, stage='test', multimodels=False, groupby='model_id'):
    """Plot per-epoch training vs. evaluation curves and save them as a PDF.

    Draws three stacked panels (loss, accuracy, ROC-AUC). Each panel compares
    a ``train_*_epoch`` column with the matching ``{stage}_*`` column.

    Args:
        df: Logged metrics with an ``epoch`` column plus
            ``train_loss_epoch``/``train_acc_epoch``/``train_roc_auc_epoch``
            and ``{stage}_loss``/``{stage}_acc``/``{stage}_roc_auc``.
        split_type: Label inserted into the output path
            ``plots/training_metrics_{split_type}.pdf``.
        stage: Prefix of the evaluation columns. ``'test'`` is labelled
            "Test"; any other value is labelled "Validation".
        multimodels: If True, average per (``groupby``, epoch) so seaborn
            aggregates across models/folds; otherwise average per epoch.
        groupby: Column identifying the model/fold when ``multimodels`` is set.
    """
    from matplotlib.ticker import PercentFormatter

    stage_label = 'Test' if stage == 'test' else 'Validation'
    df = df.dropna(how='all', axis=1)
    # Coerce every column to numbers (non-numeric -> NaN) so that ``mean``
    # can aggregate all remaining columns.
    df = df.apply(pd.to_numeric, errors='coerce')
    if multimodels:
        epoch_data = df.groupby([groupby, 'epoch']).mean().reset_index()
    else:
        epoch_data = df.groupby('epoch').mean().reset_index()

    # (train column, eval column, metric name, clamp y to [0, 1], percent ticks)
    panels = [
        ('train_loss_epoch', f'{stage}_loss', 'Loss', False, False),
        ('train_acc_epoch', f'{stage}_acc', 'Accuracy', True, True),
        ('train_roc_auc_epoch', f'{stage}_roc_auc', 'ROC-AUC', True, False),
    ]

    fig, axes = plt.subplots(3, 1, figsize=(10, 15))
    for ax, (train_col, eval_col, name, bounded, as_percent) in zip(axes, panels):
        sns.lineplot(data=epoch_data, x='epoch', y=train_col, ax=ax, label=f'Training {name}')
        sns.lineplot(data=epoch_data, x='epoch', y=eval_col, ax=ax,
                     label=f'{stage_label} {name}', linestyle='--')
        ax.set_ylabel(name)
        ax.legend(loc='lower right')
        ax.grid(axis='both', alpha=0.5)
        if bounded:
            ax.set_ylim(0, 1.0)
        if as_percent:
            ax.yaxis.set_major_formatter(PercentFormatter(1, decimals=0))
    axes[-1].set_xlabel('Epoch')

    fig.tight_layout()
    os.makedirs('plots', exist_ok=True)
    fig.savefig(f'plots/training_metrics_{split_type}.pdf', bbox_inches='tight')
    # Release the figure: this is called repeatedly and pyplot otherwise keeps
    # every figure alive for the lifetime of the process.
    plt.close(fig)
def plot_performance_metrics(df_cv, df_test, title=None):
    """Bar-plot validation/test scores per split type next to a dummy baseline.

    Validation scores come from ``df_cv``; test scores come from ``df_test``.
    A "Dummy model" series is added per split. Its accuracy is the mean share
    of the majority class, and its ROC AUC is fixed at 0.5.

    Args:
        df_cv: Cross-validation report with ``model_type``, ``fold``,
            ``val_acc``, ``val_roc_auc``, ``split_type`` and
            ``val_active_perc``/``val_inactive_perc`` columns.
        df_test: Report with ``model_type``, ``test_acc``, ``test_roc_auc``,
            ``split_type`` and ``test_active_perc``/``test_inactive_perc``.
        title: Output file stem; the plot is saved to ``plots/{title}.pdf``.
            Defaults to ``'performance_metrics'`` when None.
    """
    from matplotlib.ticker import PercentFormatter

    metric_names = {
        'val_acc': 'Validation Accuracy',
        'val_roc_auc': 'Validation ROC AUC',
        'test_acc': 'Test Accuracy',
        'test_roc_auc': 'Test ROC AUC',
    }
    group2name = {
        'random': 'Standard Split',
        'uniprot': 'Target Split',
        'tanimoto': 'Similarity Split',
    }

    # Validation scores: one row per (model, fold, metric).
    cv_data = df_cv[['model_type', 'fold', 'val_acc', 'val_roc_auc', 'split_type']].melt(
        id_vars=['model_type', 'fold', 'split_type'], var_name='Metric', value_name='Score')
    cv_data['Metric'] = cv_data['Metric'].map(metric_names)
    cv_data['Stage'] = 'Validation'

    # Test scores: one row per (model, metric).
    test_data = df_test[['model_type', 'test_acc', 'test_roc_auc', 'split_type']].melt(
        id_vars=['model_type', 'split_type'], var_name='Metric', value_name='Score')
    test_data['Metric'] = test_data['Metric'].map(metric_names)
    test_data['Stage'] = 'Test'

    combined_data = pd.concat([cv_data, test_data], ignore_index=True)
    combined_data['Split Type'] = combined_data['split_type'].map(group2name)

    # Majority-class baseline for each split type.
    dummy_scores = []
    for experiment, group in enumerate(group2name):
        val_df = df_cv[df_cv['split_type'] == group]
        val_major = 'inactive' if val_df['val_inactive_perc'].mean() > 0.5 else 'active'
        test_df = df_test[df_test['split_type'] == group]
        test_major = 'inactive' if test_df['test_inactive_perc'].mean() > 0.5 else 'active'
        scores = [
            val_df[f'val_{val_major}_perc'].mean(),
            0.5,
            test_df[f'test_{test_major}_perc'].mean(),
            0.5,
        ]
        for metric, score in zip(metric_names.values(), scores):
            dummy_scores.append({
                'Experiment': experiment,
                'Metric': metric,
                'Score': score,
                'Split Type': 'Dummy model',
            })
    combined_data = pd.concat([combined_data, pd.DataFrame(dummy_scores)], ignore_index=True)

    # Explicit category order (identical to seaborn's order-of-appearance
    # default) so each bar can be mapped back to its metric when annotating.
    metric_order = list(combined_data['Metric'].dropna().unique())

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.barplot(
        data=combined_data,
        x='Metric',
        y='Score',
        hue='Split Type',
        order=metric_order,
        errorbar=('sd', 1),
        palette=palette,
        ax=ax,
    )
    ax.set_title('')
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', alpha=0.5, linewidth=0.5)
    ax.yaxis.set_major_formatter(PercentFormatter(1, decimals=0))
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=4)

    for p in ax.patches:
        height = p.get_height()
        # Skip missing bars and tiny/zero-size patches (e.g. legend proxies
        # some seaborn versions add to the axes).
        if np.isnan(height) or height < 0.01:
            continue
        x = p.get_x() + p.get_width() / 2
        # Bars of category k are centred around x == k on a categorical axis.
        idx = int(round(x))
        metric = metric_order[idx] if 0 <= idx < len(metric_order) else ''
        if 'Accuracy' in metric:
            value = f'{height:.1%}'
        else:
            value = f'{height:.3f}'
        print(f'Plotting value: {height} -> {value}')
        ax.annotate(value, (x, 0.4), ha='center', va='center', color='black',
                    fontsize=10, rotation=90, alpha=0.8)

    os.makedirs('plots', exist_ok=True)
    filename = title if title is not None else 'performance_metrics'
    fig.savefig(f'plots/{filename}.pdf', bbox_inches='tight')
    plt.close(fig)
def plot_ablation_study(report):
    """Plot test accuracy of embedding-ablation runs, one PDF per split type.

    For each ``split_type`` in ``report`` the plot compares three groups:
    the baseline (rows whose ``disabled_embeddings`` is NaN), each run in
    ``ablation_study_combinations``, and a majority-class dummy. The dummy
    scores the larger of ``test_active_perc``/``test_inactive_perc``.
    Output goes to ``plots/ablation_study_{group}.pdf``.

    Args:
        report: Concatenated ablation and test reports with ``split_type``,
            ``disabled_embeddings``, ``test_acc``, ``test_active_perc`` and
            ``test_inactive_perc`` columns.
    """
    from matplotlib.ticker import PercentFormatter

    ablation_study_combinations = [
        'disabled smiles',
        'disabled poi',
        'disabled e3',
        'disabled cell',
        'disabled poi e3',
        'disabled poi e3 smiles',
        'disabled poi e3 cell',
    ]
    metric_labels = {
        'val_acc': 'Validation Accuracy',
        'test_acc': 'Test Accuracy',
        'val_roc_auc': 'Val ROC-AUC',
        'test_roc_auc': 'Test ROC-AUC',
    }
    embedding_labels = {
        'all embeddings enabled': 'All embeddings enabled',
        'dummy': 'Dummy model',
        'disabled smiles': 'Disabled compound information',
        'disabled e3': 'Disabled E3 information',
        'disabled poi': 'Disabled target information',
        'disabled cell': 'Disabled cell information',
        'disabled poi e3': 'Disabled E3 and target info',
        'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
        'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
    }
    metrics_to_show = ['test_acc']
    melt_kwargs = dict(id_vars=['disabled_embeddings'], value_vars=metrics_to_show,
                       var_name='metric', value_name='score')

    for group in report['split_type'].unique():
        group_report = report[report['split_type'] == group]

        baseline = group_report[group_report['disabled_embeddings'].isna()].copy()
        baseline['disabled_embeddings'] = 'all embeddings enabled'
        baseline = baseline.melt(**melt_kwargs)
        print('baseline:\n', baseline)

        ablation_df = pd.concat([
            group_report[group_report['disabled_embeddings'] == disabled].melt(**melt_kwargs)
            for disabled in ablation_study_combinations
        ])
        print('ablation_df:\n', ablation_df)

        # .to_numpy() drops the (possibly duplicated) source index.
        dummy_df = pd.DataFrame({
            'score': group_report[['test_active_perc', 'test_inactive_perc']].max(axis=1).to_numpy(),
            'metric': 'test_acc',
            'disabled_embeddings': 'dummy',
        })

        # ignore_index: the inputs carry overlapping indices, which newer
        # seaborn versions reject.
        final_df = pd.concat([dummy_df, baseline, ablation_df], ignore_index=True)
        final_df['metric'] = final_df['metric'].map(metric_labels)
        final_df['disabled_embeddings'] = final_df['disabled_embeddings'].map(embedding_labels)

        summary = final_df.groupby(['disabled_embeddings', 'metric']).mean().round(3).reset_index()
        try:
            print('DF to plot:\n', summary.to_markdown(index=False))
        except ImportError:
            # DataFrame.to_markdown needs the optional ``tabulate`` package.
            print('DF to plot:\n', summary.to_string(index=False))

        fig, ax = plt.subplots()
        sns.barplot(data=final_df,
                    y='disabled_embeddings',
                    x='score',
                    hue='metric',
                    ax=ax,
                    errorbar=('sd', 1),
                    palette=sns.color_palette(palette, len(palette)),
                    saturation=1,
                    )
        ax.grid(axis='x', alpha=0.5)
        ax.tick_params(axis='y', rotation=0)
        ax.set_xlim(0, 1.0)
        ax.xaxis.set_major_formatter(PercentFormatter(1, decimals=0))
        ax.set_ylabel('')
        ax.set_xlabel('')
        ax.legend(loc='upper right')

        for p in ax.patches:
            width = p.get_width()
            # Skip missing bars and zero-size patches (some seaborn versions
            # add zero-size legend proxies to the axes) instead of assuming
            # the proxy is always the last patch.
            if np.isnan(width) or width == 0:
                continue
            value = '{:.1f}%'.format(100 * width)
            y = p.get_y() + p.get_height() / 2
            ax.annotate(value, (0.4, y), ha='center', va='center', color='black',
                        fontsize=10, alpha=0.8)

        os.makedirs('plots', exist_ok=True)
        fig.savefig(f'plots/ablation_study_{group}.pdf', bbox_inches='tight')
        plt.close(fig)
def plot_majority_voting_performance(df):
    """Print the majority-vote report reshaped to long (Metric/Score) format.

    Nothing is plotted yet; the reshaped frame is only printed.

    NOTE(review): ``test_acc``/``test_roc_auc`` are listed as id columns, so
    only the *other* columns end up in Metric/Score — confirm this is intended.
    """
    id_columns = ['cv_models', 'test_acc', 'test_roc_auc', 'split_type']
    long_df = df.melt(id_vars=id_columns, var_name='Metric', value_name='Score')
    print(long_df)
def main():
    """Load experiment reports and training logs, then produce every plot.

    Reads ``reports/<kind>_report_<base>_<split>.csv`` for each report kind
    and split type. Also reads the per-model Lightning-style ``metrics.csv``
    files under ``logs/``. All outputs are written under ``plots/``.
    """
    active_col = 'Active (Dmax 0.6, pDC50 6.0)'
    test_split = 0.1
    n_models_for_test = 3
    cv_n_folds = 5
    split_types = ['random', 'uniprot', 'tanimoto']

    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
    report_base_name = f'{active_name}_test_split_{test_split}'

    # One report per kind, concatenated over all split types.
    report_prefixes = {
        'cv_train': 'cv_report',
        'test': 'test_report',
        'ablation': 'ablation_report',
        'hparam': 'hparam_report',
        'majority_vote': 'majority_vote_report',
    }
    reports = {
        kind: pd.concat(
            [pd.read_csv(f'reports/{prefix}_{report_base_name}_{split}.csv') for split in split_types],
            ignore_index=True,
        )
        for kind, prefix in report_prefixes.items()
    }

    for split_type in split_types:
        best_model_metrics = []
        for i in range(n_models_for_test):
            logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
            metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
            metrics['model_id'] = i
            # The logged val_* columns are plotted as test metrics (stage='test').
            metrics = metrics.rename(columns={
                'val_loss': 'test_loss',
                'val_acc': 'test_acc',
                'val_roc_auc': 'test_roc_auc',
            })
            best_model_metrics.append(metrics)
        plot_training_curves(pd.concat(best_model_metrics), f'{split_type}_best_model', multimodels=True)

        cv_metrics = []
        for fold in range(cv_n_folds):
            logs_dir = f'logs_{report_base_name}_{split_type}_{split_type}_cv_model_fold{fold}'
            metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
            metrics['fold'] = fold
            cv_metrics.append(metrics)
        plot_training_curves(pd.concat(cv_metrics), f'{split_type}_cv_model',
                             stage='val', multimodels=True, groupby='fold')

    majority_vote = reports['majority_vote']
    plot_performance_metrics(
        reports['cv_train'],
        reports['test'],
        title='mean_performance-best_models_as_test',
    )
    plot_performance_metrics(
        reports['cv_train'],
        reports['cv_train'],
        title='mean_performance-cv_models_as_test',
    )
    plot_performance_metrics(
        reports['cv_train'],
        majority_vote[majority_vote['cv_models'].isna()],
        title='majority_vote_performance-best_models_as_test',
    )
    plot_performance_metrics(
        reports['cv_train'],
        majority_vote[majority_vote['cv_models'].eq(True)],
        title='majority_vote_performance-cv_models_as_test',
    )

    # Test-report rows act as the "no embeddings disabled" baseline; build a
    # copy instead of mutating reports['test'] in place.
    test_report = reports['test'].assign(disabled_embeddings=pd.NA)
    plot_ablation_study(pd.concat([reports['ablation'], test_report], ignore_index=True))
# Run the full plotting pipeline only when executed as a script.
if __name__ == '__main__':
    main()