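"""Plotting utilities for the PROTAC-Degradation-Predictor experiment results.

Reads the CSV reports under 'reports/' and the training logs under 'logs/', and
writes the resulting figures (training curves, performance bar plots, ablation
study plots) to 'plots/'.
"""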
import os
import sys
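# Make the repository root importable when this script is run directly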
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
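# Color palette shared by the plotting functions below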
palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
def plot_training_curves(df, split_type, stage='test', multimodels=False, groupby='model_id'):
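    """Plot loss, accuracy, and ROC-AUC training curves over epochs.

    Expects a metrics dataframe with 'epoch', 'train_*_epoch', and '<stage>_*'
    columns (stage is 'test' or 'val'). When `multimodels` is True, metrics are
    averaged per `groupby` value (e.g. model id or fold) before plotting. The
    figure is saved to 'plots/training_metrics_<split_type>.pdf'.
    """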
Stage = 'Test' if stage == 'test' else 'Validation'
# Clean the data
df = df.dropna(how='all', axis=1)
# Convert all columns to numeric, setting errors='coerce' to handle non-numeric data
df = df.apply(pd.to_numeric, errors='coerce')
# Group by 'epoch' and aggregate by mean
if multimodels:
epoch_data = df.groupby([groupby, 'epoch']).mean().reset_index()
else:
epoch_data = df.groupby('epoch').mean().reset_index()
fig, ax = plt.subplots(3, 1, figsize=(10, 15))
# Plot training loss
# ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
# ax[0].plot(epoch_data.index, epoch_data[f'{stage}_loss'], label=f'{Stage} Loss', linestyle='--')
sns.lineplot(data=epoch_data, x='epoch', y='train_loss_epoch', ax=ax[0], label='Training Loss')
sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_loss', ax=ax[0], label=f'{Stage} Loss', linestyle='--')
ax[0].set_ylabel('Loss')
ax[0].legend(loc='lower right')
ax[0].grid(axis='both', alpha=0.5)
# Plot training accuracy
# ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
# ax[1].plot(epoch_data.index, epoch_data[f'{stage}_acc'], label=f'{Stage} Accuracy', linestyle='--')
sns.lineplot(data=epoch_data, x='epoch', y='train_acc_epoch', ax=ax[1], label='Training Accuracy')
sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_acc', ax=ax[1], label=f'{Stage} Accuracy', linestyle='--')
ax[1].set_ylabel('Accuracy')
ax[1].legend(loc='lower right')
ax[1].grid(axis='both', alpha=0.5)
    # Limit the y-axis to [0, 1]
    ax[1].set_ylim(0, 1.0)
    # Format the y-axis as percentages
    ax[1].yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
# Plot training ROC-AUC
# ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
# ax[2].plot(epoch_data.index, epoch_data[f'{stage}_roc_auc'], label=f'{Stage} ROC-AUC', linestyle='--')
sns.lineplot(data=epoch_data, x='epoch', y='train_roc_auc_epoch', ax=ax[2], label='Training ROC-AUC')
sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_roc_auc', ax=ax[2], label=f'{Stage} ROC-AUC', linestyle='--')
ax[2].set_ylabel('ROC-AUC')
ax[2].legend(loc='lower right')
ax[2].grid(axis='both', alpha=0.5)
    # Limit the y-axis to [0, 1]
ax[2].set_ylim(0, 1.0)
# Set x-axis label
ax[2].set_xlabel('Epoch')
plt.tight_layout()
plt.savefig(f'plots/training_metrics_{split_type}.pdf', bbox_inches='tight')
def plot_performance_metrics(df_cv, df_test, title=None):
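    """Plot mean validation (CV) and test performance as grouped bar plots.

    `df_cv` provides per-fold 'val_acc'/'val_roc_auc' scores and `df_test`
    provides 'test_acc'/'test_roc_auc' scores; both need a 'split_type' column.
    A majority-class dummy model (0.5 ROC AUC) is added as a baseline, and the
    figure is saved to 'plots/<title>.pdf'.
    """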
# Extract and prepare CV data
cv_data = df_cv[['model_type', 'fold', 'val_acc', 'val_roc_auc', 'test_acc', 'test_roc_auc', 'split_type']]
cv_data = cv_data.melt(id_vars=['model_type', 'fold', 'split_type'], var_name='Metric', value_name='Score')
cv_data['Metric'] = cv_data['Metric'].replace({
'val_acc': 'Validation Accuracy',
'val_roc_auc': 'Validation ROC AUC',
'test_acc': 'Test Accuracy',
'test_roc_auc': 'Test ROC AUC'
})
cv_data['Stage'] = cv_data['Metric'].apply(lambda x: 'Validation' if 'Val' in x else 'Test')
# Remove test data from CV data
cv_data = cv_data[cv_data['Stage'] == 'Validation']
# Extract and prepare test data
test_data = df_test[['model_type', 'test_acc', 'test_roc_auc', 'split_type']]
test_data = test_data.melt(id_vars=['model_type', 'split_type'], var_name='Metric', value_name='Score')
test_data['Metric'] = test_data['Metric'].replace({
'test_acc': 'Test Accuracy',
'test_roc_auc': 'Test ROC AUC'
})
test_data['Stage'] = 'Test'
# Combine CV and test data
combined_data = pd.concat([cv_data, test_data], ignore_index=True)
# Rename 'split_type' values according to a predefined map for clarity
group2name = {
'random': 'Standard Split',
'uniprot': 'Target Split',
'tanimoto': 'Similarity Split',
}
combined_data['Split Type'] = combined_data['split_type'].map(group2name)
# Add dummy model data
dummy_val_acc = []
dummy_test_acc = []
for i, group in enumerate(group2name.keys()):
# Get the majority class in group_df
group_df = df_cv[df_cv['split_type'] == group]
major_col = 'inactive' if group_df['val_inactive_perc'].mean() > 0.5 else 'active'
dummy_val_acc.append(group_df[f'val_{major_col}_perc'].mean())
group_df = df_test[df_test['split_type'] == group]
major_col = 'inactive' if group_df['test_inactive_perc'].mean() > 0.5 else 'active'
dummy_test_acc.append(group_df[f'test_{major_col}_perc'].mean())
dummy_scores = []
metrics = ['Validation Accuracy', 'Validation ROC AUC', 'Test Accuracy', 'Test ROC AUC']
for i in range(len(dummy_val_acc)):
for metric, score in zip(metrics, [dummy_val_acc[i], 0.5, dummy_test_acc[i], 0.5]):
dummy_scores.append({
'Experiment': i,
'Metric': metric,
'Score': score,
'Split Type': 'Dummy model',
})
dummy_model = pd.DataFrame(dummy_scores)
combined_data = pd.concat([combined_data, dummy_model], ignore_index=True)
# Plotting
plt.figure(figsize=(12, 6))
sns.barplot(
data=combined_data,
x='Metric',
y='Score',
hue='Split Type',
errorbar=('sd', 1),
palette=palette)
plt.title('')
plt.ylabel('')
plt.xlabel('')
plt.ylim(0, 1.0) # Assuming scores are normalized between 0 and 1
plt.grid(axis='y', alpha=0.5, linewidth=0.5)
    # Format the y-axis as percentages
plt.gca().yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
    # Place the legend below the x-axis, outside the plot, in four columns
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=4)
    # For each bar, write its value inside the bar, rotated (accuracy as a percentage, ROC AUC as a plain number)
for i, p in enumerate(plt.gca().patches):
        # TODO: For some reason, there are 4 additional rectangles being
        # plotted... I suspect it's because dummy_df doesn't have the same
        # shape as the dataframe containing all the evaluation data...
if p.get_height() < 0.01:
continue
if i % 2 == 0:
value = f'{p.get_height():.1%}'
else:
value = f'{p.get_height():.3f}'
print(f'Plotting value: {p.get_height()} -> {value}')
x = p.get_x() + p.get_width() / 2
y = 0.4 # p.get_height() - p.get_height() / 2
plt.annotate(value, (x, y), ha='center', va='center', color='black', fontsize=10, rotation=90, alpha=0.8)
plt.savefig(f'plots/{title}.pdf', bbox_inches='tight')
def plot_ablation_study(report):
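    """Plot the ablation study results, one horizontal bar plot per split type.

    Rows of `report` with a missing 'disabled_embeddings' value are used as the
    'all embeddings enabled' baseline, and a majority-class dummy model is added
    for comparison. Each figure is saved to 'plots/ablation_study_<split_type>.pdf'.
    """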
# Define the ablation study combinations
ablation_study_combinations = [
'disabled smiles',
'disabled poi',
'disabled e3',
'disabled cell',
'disabled poi e3',
'disabled poi e3 smiles',
'disabled poi e3 cell',
]
for group in report['split_type'].unique():
baseline = report[report['disabled_embeddings'].isna()].copy()
baseline = baseline[baseline['split_type'] == group]
baseline['disabled_embeddings'] = 'all embeddings enabled'
# metrics_to_show = ['val_acc', 'test_acc']
metrics_to_show = ['test_acc']
# baseline = baseline.melt(id_vars=['fold', 'disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
baseline = baseline.melt(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
print('baseline:\n', baseline)
ablation_dfs = []
for disabled_embeddings in ablation_study_combinations:
tmp = report[report['disabled_embeddings'] == disabled_embeddings].copy()
tmp = tmp[tmp['split_type'] == group]
# tmp = tmp.melt(id_vars=['fold', 'disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
tmp = tmp.melt(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
ablation_dfs.append(tmp)
ablation_df = pd.concat(ablation_dfs)
print('ablation_df:\n', ablation_df)
# dummy_val_df = pd.DataFrame()
# tmp = report[report['split_type'] == group]
# dummy_val_df['score'] = tmp[['val_active_perc', 'val_inactive_perc']].max(axis=1)
# dummy_val_df['metric'] = 'val_acc'
# dummy_val_df['disabled_embeddings'] = 'dummy'
dummy_test_df = pd.DataFrame()
tmp = report[report['split_type'] == group]
dummy_test_df['score'] = tmp[['test_active_perc', 'test_inactive_perc']].max(axis=1)
dummy_test_df['metric'] = 'test_acc'
dummy_test_df['disabled_embeddings'] = 'dummy'
# dummy_df = pd.concat([dummy_val_df, dummy_test_df])
dummy_df = dummy_test_df
final_df = pd.concat([dummy_df, baseline, ablation_df])
final_df['metric'] = final_df['metric'].map({
'val_acc': 'Validation Accuracy',
'test_acc': 'Test Accuracy',
'val_roc_auc': 'Val ROC-AUC',
'test_roc_auc': 'Test ROC-AUC',
})
final_df['disabled_embeddings'] = final_df['disabled_embeddings'].map({
'all embeddings enabled': 'All embeddings enabled',
'dummy': 'Dummy model',
'disabled smiles': 'Disabled compound information',
'disabled e3': 'Disabled E3 information',
'disabled poi': 'Disabled target information',
'disabled cell': 'Disabled cell information',
'disabled poi e3': 'Disabled E3 and target info',
'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
})
        # Print the aggregated results as a markdown table
        tmp = final_df.groupby(['disabled_embeddings', 'metric']).mean().round(3)
        # Reset the index so the group keys become columns again
        tmp = tmp.reset_index()  # .drop('fold', axis=1)
print('DF to plot:\n', tmp.to_markdown(index=False))
# fig, ax = plt.subplots(figsize=(5, 5))
fig, ax = plt.subplots()
sns.barplot(data=final_df,
y='disabled_embeddings',
x='score',
hue='metric',
ax=ax,
errorbar=('sd', 1),
palette=sns.color_palette(palette, len(palette)),
saturation=1,
)
# ax.set_title(f'{group.replace("random", "standard")} CV split')
ax.grid(axis='x', alpha=0.5)
ax.tick_params(axis='y', rotation=0)
ax.set_xlim(0, 1.0)
ax.xaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
ax.set_ylabel('')
ax.set_xlabel('')
# Set the legend outside the plot and below
# ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=2)
# Set the legend in the upper right corner
ax.legend(loc='upper right')
# For each bar, add the rotated value (as percentage), inside the bar
for i, p in enumerate(plt.gca().patches):
            # TODO: For some reason, an additional bar is added at the end of
            # the plot... it's not in the dataframe
if i == len(plt.gca().patches) - 1:
continue
value = '{:.1f}%'.format(100 * p.get_width())
y = p.get_y() + p.get_height() / 2
x = 0.4 # p.get_height() - p.get_height() / 2
plt.annotate(value, (x, y), ha='center', va='center', color='black', fontsize=10, alpha=0.8)
plt.savefig(f'plots/ablation_study_{group}.pdf', bbox_inches='tight')
def plot_majority_voting_performance(df):
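    """Melt and print the majority-vote report.

    Note: with every column passed as `id_vars`, the melt produces no value
    rows; the call to this function in `main()` is currently commented out.
    """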
    # Expected columns: cv_models, test_acc, test_roc_auc, split_type
# Melt the dataframe
df = df.melt(id_vars=['cv_models', 'test_acc', 'test_roc_auc', 'split_type'], var_name='Metric', value_name='Score')
print(df)
def main():
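    """Load the experiment report CSVs and generate all figures.

    Report CSVs are read from 'reports/', per-model training logs from 'logs/',
    and figures are written to 'plots/'; all paths are relative to the working
    directory.
    """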
active_col = 'Active (Dmax 0.6, pDC50 6.0)'
test_split = 0.1
n_models_for_test = 3
cv_n_folds = 5
active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
report_base_name = f'{active_name}_test_split_{test_split}'
# Load the data
reports = {
'cv_train': pd.concat([
pd.read_csv(f'reports/cv_report_{report_base_name}_random.csv'),
pd.read_csv(f'reports/cv_report_{report_base_name}_uniprot.csv'),
pd.read_csv(f'reports/cv_report_{report_base_name}_tanimoto.csv'),
]),
'test': pd.concat([
pd.read_csv(f'reports/test_report_{report_base_name}_random.csv'),
pd.read_csv(f'reports/test_report_{report_base_name}_uniprot.csv'),
pd.read_csv(f'reports/test_report_{report_base_name}_tanimoto.csv'),
]),
'ablation': pd.concat([
pd.read_csv(f'reports/ablation_report_{report_base_name}_random.csv'),
pd.read_csv(f'reports/ablation_report_{report_base_name}_uniprot.csv'),
pd.read_csv(f'reports/ablation_report_{report_base_name}_tanimoto.csv'),
]),
'hparam': pd.concat([
pd.read_csv(f'reports/hparam_report_{report_base_name}_random.csv'),
pd.read_csv(f'reports/hparam_report_{report_base_name}_uniprot.csv'),
pd.read_csv(f'reports/hparam_report_{report_base_name}_tanimoto.csv'),
]),
'majority_vote': pd.concat([
pd.read_csv(f'reports/majority_vote_report_{report_base_name}_random.csv'),
pd.read_csv(f'reports/majority_vote_report_{report_base_name}_uniprot.csv'),
pd.read_csv(f'reports/majority_vote_report_{report_base_name}_tanimoto.csv'),
]),
}
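    # Plot training curves for the best models and the CV models of each split type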
for split_type in ['random', 'tanimoto', 'uniprot']:
split_metrics = []
for i in range(n_models_for_test):
logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
metrics['model_id'] = i
# Rename 'val_' columns to 'test_' columns
metrics = metrics.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
# plot_training_curves(metrics, f'{split_type}_best_model_n{i}')
split_metrics.append(metrics)
plot_training_curves(pd.concat(split_metrics), f'{split_type}_best_model', multimodels=True)
split_metrics_cv = []
for i in range(cv_n_folds):
# logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
logs_dir = f'logs_{report_base_name}_{split_type}_{split_type}_cv_model_fold{i}'
metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
metrics['fold'] = i
# plot_training_curves(metrics, f'{split_type}_cv_model_fold{i}', stage='val')
split_metrics_cv.append(metrics)
plot_training_curves(pd.concat(split_metrics_cv), f'{split_type}_cv_model', stage='val', multimodels=True, groupby='fold')
plot_performance_metrics(
reports['cv_train'],
reports['test'],
        title='mean_performance-best_models_as_test',
)
plot_performance_metrics(
reports['cv_train'],
reports['cv_train'],
        title='mean_performance-cv_models_as_test',
)
plot_performance_metrics(
reports['cv_train'],
reports['majority_vote'][reports['majority_vote']['cv_models'].isna()],
        title='majority_vote_performance-best_models_as_test',
)
plot_performance_metrics(
reports['cv_train'],
reports['majority_vote'][reports['majority_vote']['cv_models'] == True],
        title='majority_vote_performance-cv_models_as_test',
)
# plot_majority_voting_performance(reports['majority_vote'])
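    # The test report rows (with 'disabled_embeddings' set to NA) act as the
    # 'all embeddings enabled' baseline in the ablation plot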
reports['test']['disabled_embeddings'] = pd.NA
plot_ablation_study(pd.concat([
reports['ablation'],
reports['test'],
]))
    # # Print the hyperparameter optimization results as a markdown table
# print(reports['hparam'][['split_type', 'hidden_dim', 'learning_rate', 'dropout', 'use_smote', 'smote_k_neighbors']].to_markdown(index=False))
if __name__ == '__main__':
main()