|
import os |
|
import sys |
|
|
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import pandas as pd |
|
import numpy as np |
|
|
|
|
|
# Colour palette shared by the bar plots in this module
# (hex RGB strings, consumed in order as hue colours).
palette = [
    '#83B8FE',  # light blue
    '#FFA54C',  # orange
    '#94ED67',  # light green
    '#FF7FFF',  # pink
]
|
|
|
|
|
def plot_training_curves(df, split_type, stage='test', multimodels=False, groupby='model_id'):
    """Plot per-epoch loss, accuracy and ROC-AUC curves and save them to a PDF.

    The figure has three stacked panels (loss, accuracy, ROC-AUC). Each panel
    shows the training metric as a solid line and the ``stage`` metric as a
    dashed line. The figure is written to
    ``plots/training_metrics_{split_type}.pdf`` and then closed.

    Args:
        df: Metrics log (e.g. one or more concatenated ``metrics.csv`` files).
            Must contain ``epoch``, ``train_loss_epoch``, ``train_acc_epoch``,
            ``train_roc_auc_epoch`` and ``{stage}_loss``, ``{stage}_acc``,
            ``{stage}_roc_auc``; plus the ``groupby`` column when
            ``multimodels`` is True.
        split_type: Tag used only in the output file name.
        stage: Prefix of the evaluation-metric columns. ``'test'`` is
            labelled "Test"; any other value is labelled "Validation".
        multimodels: If True, average per (``groupby``, epoch) instead of per
            epoch, so seaborn aggregates across models/folds (mean line with
            an error band).
        groupby: Column identifying the model/fold when ``multimodels`` is True.
    """
    stage_label = 'Test' if stage == 'test' else 'Validation'

    # Drop metrics that were never logged (all-NaN columns).
    df = df.dropna(how='all', axis=1)
    # Coerce everything to numbers so the groupby mean works; unparsable
    # entries become NaN.
    df = df.apply(pd.to_numeric, errors='coerce')

    # Collapse to one row per epoch (per model/fold when multimodels). The
    # mean skips NaN, which merges train and eval values if they were logged
    # on separate rows of the same epoch (NOTE(review): presumably the case
    # for these CSV logs - confirm).
    group_keys = [groupby, 'epoch'] if multimodels else 'epoch'
    epoch_data = df.groupby(group_keys).mean().reset_index()

    fig, ax = plt.subplots(3, 1, figsize=(10, 15))

    # Panel 1: loss.
    sns.lineplot(data=epoch_data, x='epoch', y='train_loss_epoch', ax=ax[0], label='Training Loss')
    sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_loss', ax=ax[0], label=f'{stage_label} Loss', linestyle='--')
    ax[0].set_ylabel('Loss')
    ax[0].legend(loc='lower right')
    ax[0].grid(axis='both', alpha=0.5)

    # Panel 2: accuracy, shown as a percentage on a fixed 0-100% axis.
    sns.lineplot(data=epoch_data, x='epoch', y='train_acc_epoch', ax=ax[1], label='Training Accuracy')
    sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_acc', ax=ax[1], label=f'{stage_label} Accuracy', linestyle='--')
    ax[1].set_ylabel('Accuracy')
    ax[1].legend(loc='lower right')
    ax[1].grid(axis='both', alpha=0.5)
    ax[1].set_ylim(0, 1.0)
    ax[1].yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))

    # Panel 3: ROC-AUC on a fixed [0, 1] axis.
    sns.lineplot(data=epoch_data, x='epoch', y='train_roc_auc_epoch', ax=ax[2], label='Training ROC-AUC')
    sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_roc_auc', ax=ax[2], label=f'{stage_label} ROC-AUC', linestyle='--')
    ax[2].set_ylabel('ROC-AUC')
    ax[2].legend(loc='lower right')
    ax[2].grid(axis='both', alpha=0.5)
    ax[2].set_ylim(0, 1.0)
    ax[2].set_xlabel('Epoch')

    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    fig.savefig(f'plots/training_metrics_{split_type}.pdf', bbox_inches='tight')
    # Release the figure: pyplot keeps every unclosed figure alive, and this
    # function is called repeatedly (once per split and model kind).
    plt.close(fig)
|
|
|
|
|
def plot_performance_metrics(df_cv, df_test, title=None):
    """Bar-plot accuracy and ROC-AUC per split type next to a dummy baseline.

    Validation scores are taken from ``df_cv`` and test scores from
    ``df_test``. A "Dummy model" group is added that corresponds to always
    predicting the majority class: its accuracy is the mean share of that
    class and its ROC-AUC is fixed at 0.5. The figure is written to
    ``plots/{title}.pdf`` and then closed.

    Args:
        df_cv: Cross-validation report with columns ``model_type``, ``fold``,
            ``split_type``, ``val_acc``, ``val_roc_auc``, ``test_acc``,
            ``test_roc_auc``, ``val_active_perc`` and ``val_inactive_perc``.
        df_test: Test report with columns ``model_type``, ``split_type``,
            ``test_acc``, ``test_roc_auc``, ``test_active_perc`` and
            ``test_inactive_perc``.
        title: File stem of the saved PDF. Defaults to
            ``'performance_metrics'`` when None.
    """
    if title is None:
        # Avoid silently writing to 'plots/None.pdf'.
        title = 'performance_metrics'

    # --- Validation scores from the CV report (long format) ---
    cv_data = df_cv[['model_type', 'fold', 'val_acc', 'val_roc_auc', 'test_acc', 'test_roc_auc', 'split_type']]
    cv_data = cv_data.melt(id_vars=['model_type', 'fold', 'split_type'], var_name='Metric', value_name='Score')
    cv_data['Metric'] = cv_data['Metric'].replace({
        'val_acc': 'Validation Accuracy',
        'val_roc_auc': 'Validation ROC AUC',
        'test_acc': 'Test Accuracy',
        'test_roc_auc': 'Test ROC AUC'
    })
    cv_data['Stage'] = cv_data['Metric'].apply(lambda x: 'Validation' if 'Val' in x else 'Test')
    # Only validation scores are kept from the CV report; test scores come
    # from df_test below.
    cv_data = cv_data[cv_data['Stage'] == 'Validation']

    # --- Test scores (long format) ---
    test_data = df_test[['model_type', 'test_acc', 'test_roc_auc', 'split_type']]
    test_data = test_data.melt(id_vars=['model_type', 'split_type'], var_name='Metric', value_name='Score')
    test_data['Metric'] = test_data['Metric'].replace({
        'test_acc': 'Test Accuracy',
        'test_roc_auc': 'Test ROC AUC'
    })
    test_data['Stage'] = 'Test'

    combined_data = pd.concat([cv_data, test_data], ignore_index=True)

    # Display names for the split types; unknown split types map to NaN.
    group2name = {
        'random': 'Standard Split',
        'uniprot': 'Target Split',
        'tanimoto': 'Similarity Split',
    }
    combined_data['Split Type'] = combined_data['split_type'].map(group2name)

    # --- Dummy (majority-class) baseline, one experiment per split type ---
    dummy_val_acc = []
    dummy_test_acc = []
    for group in group2name:
        group_df = df_cv[df_cv['split_type'] == group]
        major_col = 'inactive' if group_df['val_inactive_perc'].mean() > 0.5 else 'active'
        dummy_val_acc.append(group_df[f'val_{major_col}_perc'].mean())

        group_df = df_test[df_test['split_type'] == group]
        major_col = 'inactive' if group_df['test_inactive_perc'].mean() > 0.5 else 'active'
        dummy_test_acc.append(group_df[f'test_{major_col}_perc'].mean())

    metrics = ['Validation Accuracy', 'Validation ROC AUC', 'Test Accuracy', 'Test ROC AUC']
    dummy_scores = [
        {
            'Experiment': i,
            'Metric': metric,
            'Score': score,
            'Split Type': 'Dummy model',
        }
        for i, (val_acc, test_acc) in enumerate(zip(dummy_val_acc, dummy_test_acc))
        # A majority-class predictor has ROC-AUC 0.5.
        for metric, score in zip(metrics, [val_acc, 0.5, test_acc, 0.5])
    ]
    combined_data = pd.concat([combined_data, pd.DataFrame(dummy_scores)], ignore_index=True)

    # --- Plot ---
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.barplot(
        data=combined_data,
        x='Metric',
        y='Score',
        hue='Split Type',
        errorbar=('sd', 1),
        palette=palette,
        ax=ax)
    ax.set_title('')
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', alpha=0.5, linewidth=0.5)
    ax.yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=4)

    # Write each bar's value inside the bar. NOTE(review): assumes patches are
    # ordered hue by hue with one bar per metric in the x order above, so even
    # indices are accuracies (shown as %) and odd indices are ROC-AUCs (3
    # decimals) - re-check on seaborn upgrades. Missing (NaN) and near-zero
    # patches (presumably seaborn legend proxies) are skipped.
    for i, p in enumerate(ax.patches):
        height = p.get_height()
        if np.isnan(height) or height < 0.01:
            continue
        value = f'{height:.1%}' if i % 2 == 0 else f'{height:.3f}'
        print(f'Plotting value: {height} -> {value}')
        x = p.get_x() + p.get_width() / 2
        ax.annotate(value, (x, 0.4), ha='center', va='center', color='black', fontsize=10, rotation=90, alpha=0.8)

    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    fig.savefig(f'plots/{title}.pdf', bbox_inches='tight')
    # Release the figure; main() calls this function several times.
    plt.close(fig)
|
|
|
|
|
def plot_ablation_study(report):
    """Plot, per split type, test accuracy with individual input embeddings disabled.

    For every ``split_type`` in ``report`` a horizontal bar plot is saved to
    ``plots/ablation_study_{split_type}.pdf``. It compares:

    * a dummy model, scored per row as
      ``max(test_active_perc, test_inactive_perc)``;
    * the baseline, i.e. rows whose ``disabled_embeddings`` is missing/NA;
    * each setting listed in ``ablation_study_combinations``.

    Args:
        report: Concatenated ablation and test reports with at least the
            columns ``split_type``, ``disabled_embeddings``, ``test_acc``,
            ``test_active_perc`` and ``test_inactive_perc``.
    """
    # Ablation settings, as stored in 'disabled_embeddings', in the order
    # they are concatenated for plotting.
    ablation_study_combinations = [
        'disabled smiles',
        'disabled poi',
        'disabled e3',
        'disabled cell',
        'disabled poi e3',
        'disabled poi e3 smiles',
        'disabled poi e3 cell',
    ]
    metrics_to_show = ['test_acc']
    # Raw value -> display label; values missing from these maps become NaN.
    metric2name = {
        'val_acc': 'Validation Accuracy',
        'test_acc': 'Test Accuracy',
        'val_roc_auc': 'Val ROC-AUC',
        'test_roc_auc': 'Test ROC-AUC',
    }
    embeddings2name = {
        'all embeddings enabled': 'All embeddings enabled',
        'dummy': 'Dummy model',
        'disabled smiles': 'Disabled compound information',
        'disabled e3': 'Disabled E3 information',
        'disabled poi': 'Disabled target information',
        'disabled cell': 'Disabled cell information',
        'disabled poi e3': 'Disabled E3 and target info',
        'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
        'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
    }
    melt_kwargs = dict(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')

    for group in report['split_type'].unique():
        group_report = report[report['split_type'] == group]

        # Baseline: rows without any disabled embedding.
        baseline = group_report[group_report['disabled_embeddings'].isna()].copy()
        baseline['disabled_embeddings'] = 'all embeddings enabled'
        baseline = baseline.melt(**melt_kwargs)
        print('baseline:\n', baseline)

        ablation_df = pd.concat([
            group_report[group_report['disabled_embeddings'] == disabled].melt(**melt_kwargs)
            for disabled in ablation_study_combinations
        ])
        print('ablation_df:\n', ablation_df)

        # Dummy model: accuracy of predicting the majority class, per row of
        # this split (both ablation and baseline rows).
        dummy_df = pd.DataFrame({'score': group_report[['test_active_perc', 'test_inactive_perc']].max(axis=1)})
        dummy_df['metric'] = 'test_acc'
        dummy_df['disabled_embeddings'] = 'dummy'

        # ignore_index: the source frames have overlapping indices.
        final_df = pd.concat([dummy_df, baseline, ablation_df], ignore_index=True)
        final_df['metric'] = final_df['metric'].map(metric2name)
        final_df['disabled_embeddings'] = final_df['disabled_embeddings'].map(embeddings2name)

        summary = final_df.groupby(['disabled_embeddings', 'metric']).mean().round(3).reset_index()
        print('DF to plot:\n', summary.to_markdown(index=False))

        fig, ax = plt.subplots()
        sns.barplot(data=final_df,
                    y='disabled_embeddings',
                    x='score',
                    hue='metric',
                    ax=ax,
                    errorbar=('sd', 1),
                    palette=sns.color_palette(palette, len(palette)),
                    saturation=1,
                    )
        ax.grid(axis='x', alpha=0.5)
        ax.tick_params(axis='y', rotation=0)
        ax.set_xlim(0, 1.0)
        ax.xaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
        ax.set_ylabel('')
        ax.set_xlabel('')
        ax.legend(loc='upper right')

        # Annotate each bar's width as a percentage. The last patch is
        # skipped (NOTE(review): presumably a zero-size legend proxy added by
        # seaborn - confirm for the installed seaborn version).
        n_patches = len(ax.patches)
        for i, p in enumerate(ax.patches):
            if i == n_patches - 1:
                continue
            value = '{:.1f}%'.format(100 * p.get_width())
            y = p.get_y() + p.get_height() / 2
            ax.annotate(value, (0.4, y), ha='center', va='center', color='black', fontsize=10, alpha=0.8)

        # savefig does not create missing directories.
        os.makedirs('plots', exist_ok=True)
        fig.savefig(f'plots/ablation_study_{group}.pdf', bbox_inches='tight')
        # One figure per split type: close it so figures do not accumulate.
        plt.close(fig)
|
|
|
|
|
def plot_majority_voting_performance(df):
    """Reshape a majority-vote report into long (Metric/Score) format and print it.

    NOTE: despite the name, nothing is plotted yet; the long-format frame is
    printed and returned so a plot can be built on top of it.

    The metric columns are melted as values (``value_vars``). Previously they
    were passed as ``id_vars``, which melted every *other* column instead.

    Args:
        df: Majority-vote report with at least the columns ``cv_models``,
            ``split_type``, ``test_acc`` and ``test_roc_auc``. Other columns
            are ignored.

    Returns:
        DataFrame with columns ``cv_models``, ``split_type``, ``Metric``
        (``'test_acc'`` or ``'test_roc_auc'``) and ``Score``.
    """
    df = df.melt(
        id_vars=['cv_models', 'split_type'],
        value_vars=['test_acc', 'test_roc_auc'],
        var_name='Metric',
        value_name='Score',
    )
    print(df)
    return df
|
|
|
|
|
def main():
    """Load the experiment reports and training logs, then write every figure.

    Reads the CSV reports under ``reports/`` and the per-model training logs
    under ``logs/``. It then produces the training-curve, performance and
    ablation plots via the plotting functions in this module.
    """
    active_col = 'Active (Dmax 0.6, pDC50 6.0)'
    test_split = 0.1
    n_models_for_test = 3
    cv_n_folds = 5

    # File-name friendly label: spaces -> '_', parentheses and commas dropped.
    active_name = active_col.translate(str.maketrans({' ': '_', '(': None, ')': None, ',': None}))
    report_base_name = f'{active_name}_test_split_{test_split}'

    # Report kind -> CSV file prefix; each kind is concatenated over the
    # three split strategies in this order.
    report_splits = ('random', 'uniprot', 'tanimoto')
    report_prefixes = {
        'cv_train': 'cv_report',
        'test': 'test_report',
        'ablation': 'ablation_report',
        'hparam': 'hparam_report',
        'majority_vote': 'majority_vote_report',
    }
    reports = {
        key: pd.concat([
            pd.read_csv(f'reports/{prefix}_{report_base_name}_{split}.csv')
            for split in report_splits
        ])
        for key, prefix in report_prefixes.items()
    }

    # Training curves: best models (evaluated as "test") and CV folds ("val").
    for split_type in ('random', 'tanimoto', 'uniprot'):
        best_model_logs = []
        for model_id in range(n_models_for_test):
            logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{model_id}'
            log = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
            log['model_id'] = model_id
            # The logged 'val_*' columns are plotted as the test stage
            # (presumably the best models used the test set for validation).
            best_model_logs.append(
                log.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
            )
        plot_training_curves(pd.concat(best_model_logs), f'{split_type}_best_model', multimodels=True)

        cv_logs = []
        for fold in range(cv_n_folds):
            logs_dir = f'logs_{report_base_name}_{split_type}_{split_type}_cv_model_fold{fold}'
            log = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
            log['fold'] = fold
            cv_logs.append(log)
        plot_training_curves(pd.concat(cv_logs), f'{split_type}_cv_model', stage='val', multimodels=True, groupby='fold')

    # Performance bar plots: CV validation scores against various test sets.
    majority_vote = reports['majority_vote']
    performance_plots = [
        (reports['test'], 'mean_performance-best_models_as_test'),
        (reports['cv_train'], 'mean_performance-cv_models_as_test'),
        (majority_vote[majority_vote['cv_models'].isna()], 'majority_vote_performance-best_models_as_test'),
        (majority_vote[majority_vote['cv_models'] == True], 'majority_vote_performance-cv_models_as_test'),
    ]
    for test_df, plot_title in performance_plots:
        plot_performance_metrics(reports['cv_train'], test_df, title=plot_title)

    # Test-report rows get an NA 'disabled_embeddings' so the ablation plot
    # treats them as the all-embeddings-enabled baseline.
    reports['test']['disabled_embeddings'] = pd.NA
    plot_ablation_study(pd.concat([reports['ablation'], reports['test']]))


if __name__ == '__main__':
    main()