# fuson_plm/utils/visualizing.py
# Utility functions for visualizing train/val/test splits, used throughout FusOn-pLM
# training and benchmarking.
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from matplotlib.font_manager import FontProperties
from scipy.stats import entropy
from sklearn.manifold import TSNE
import pickle
import pandas as pd
import os
import numpy as np
from fuson_plm.utils.logging import log_update, find_fuson_plm_directory
def set_font():
# Load and set the font
fuson_plm_dir = find_fuson_plm_directory()
# Paths for regular, bold, italic fonts
regular_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Regular.ttf')
bold_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Bold.ttf')
italic_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Italic.ttf')
bold_italic_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-BoldItalic.ttf')
# Load the font properties
regular_font = FontProperties(fname=regular_font_path)
bold_font = FontProperties(fname=bold_font_path)
italic_font = FontProperties(fname=italic_font_path)
bold_italic_font = FontProperties(fname=bold_italic_font_path)
# Add the fonts to the font manager
fm.fontManager.addfont(regular_font_path)
fm.fontManager.addfont(bold_font_path)
fm.fontManager.addfont(italic_font_path)
fm.fontManager.addfont(bold_italic_font_path)
# Set the font family globally to Ubuntu
plt.rcParams['font.family'] = regular_font.get_name()
# Set the fonts for math text (like for labels) to use the loaded Ubuntu fonts
plt.rcParams['mathtext.fontset'] = 'custom'
plt.rcParams['mathtext.rm'] = regular_font.get_name()
plt.rcParams['mathtext.it'] = f'{italic_font.get_name()}'
plt.rcParams['mathtext.bf'] = f'{bold_font.get_name()}'
global default_color_map
default_color_map = {
'train': '#0072B2',
'val': '#009E73',
'test': '#E69F00'
}
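# Illustrative usage (not part of the pipeline): call set_font() once before plotting so that
# Matplotlib picks up the bundled Ubuntu fonts and the shared train/val/test color map, e.g.
#   set_font()
#   plt.plot([0, 1], [0, 1], color=default_color_map['train'])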
def get_avg_embeddings_for_tsne(train_sequences=None, val_sequences=None, test_sequences=None, embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
    """
    Load pre-computed average embeddings and split them into train, validation, and test lists
    according to which sequences appear in each set.
    """
    if train_sequences is None: train_sequences = []
    if val_sequences is None: val_sequences = []
    if test_sequences is None: test_sequences = []
    embeddings = {}
    try:
        with open(embedding_path, 'rb') as f:
            # Expected format: dict mapping sequence (str) -> average embedding (1-D array)
            embeddings = pickle.load(f)
        train_embeddings = [v for k, v in embeddings.items() if k in train_sequences]
        val_embeddings = [v for k, v in embeddings.items() if k in val_sequences]
        test_embeddings = [v for k, v in embeddings.items() if k in test_sequences]
        return train_embeddings, val_embeddings, test_embeddings
    except (OSError, pickle.UnpicklingError) as e:
        log_update(f"Could not open embeddings at {embedding_path}: {e}")
        return [], [], []
def calculate_aa_composition(sequences):
    """Calculate the relative frequency of each amino acid across all of the given sequences."""
    composition = {}
total_length = sum([len(seq) for seq in sequences])
for seq in sequences:
for aa in seq:
if aa in composition:
composition[aa] += 1
else:
composition[aa] = 1
# Convert counts to relative frequency
for aa in composition:
composition[aa] /= total_length
return composition
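# Worked example (illustrative): for sequences ["ACD", "AAD"], the pooled length is 6, so
#   calculate_aa_composition(["ACD", "AAD"]) -> {'A': 0.5, 'C': 0.1667, 'D': 0.3333}
# i.e. frequencies are computed over the combined length of all sequences, not per sequence.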
def calculate_shannon_entropy(sequence):
"""
Calculate the Shannon entropy for a given sequence.
Args:
sequence (str): A sequence of characters (e.g., amino acids or nucleotides).
Returns:
float: Shannon entropy value.
"""
bases = set(sequence)
counts = [sequence.count(base) for base in bases]
return entropy(counts, base=2)
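# Worked example (illustrative): a single-letter sequence has zero entropy, while an even mix of
# two letters gives exactly 1 bit:
#   calculate_shannon_entropy("AAAA") -> 0.0
#   calculate_shannon_entropy("ABAB") -> 1.0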
def visualize_splits_hist(train_lengths=None, val_lengths=None, test_lengths=None, colormap=None, savepath='splits/length_distributions.png', axes=None):
    """
    Plot sequence-length histograms for train/val/test, train/val, or train/test splits.
    """
set_font()
if colormap is None: colormap=default_color_map
log_update('\nMaking histogram of length distributions')
# Get index for test plot
val_plot_index, test_plot_index, total_plots = 1, 2, 3
if val_lengths is None:
val_plot_index = None
test_plot_index-= 1
total_plots-=1
if test_lengths is None:
test_plot_index = None
total_plots-=1
    # Create a standalone figure with one row of subplots (one per provided split)
    fig_individual, axes_individual = plt.subplots(1, total_plots, figsize=(6*total_plots, 6))
# Set axes list
axes_list = [axes_individual] if axes is None else [axes_individual, axes]
    # Axis labels shared by all panels
    xlabel, ylabel = 'Sequence Length (AA)', 'Frequency'
for cur_axes in axes_list:
# Plot the first histogram
cur_axes[0].hist(train_lengths, bins=20, edgecolor='k',color=colormap['train'])
cur_axes[0].set_xlabel(xlabel)
cur_axes[0].set_ylabel(ylabel)
cur_axes[0].set_title(f'Train Set Length Distribution (n={len(train_lengths)})')
cur_axes[0].grid(True)
cur_axes[0].set_axisbelow(True)
# Plot the second histogram
if not(val_plot_index is None):
cur_axes[val_plot_index].hist(val_lengths, bins=20, edgecolor='k',color=colormap['val'])
cur_axes[val_plot_index].set_xlabel(xlabel)
cur_axes[val_plot_index].set_ylabel(ylabel)
cur_axes[val_plot_index].set_title(f'Validation Set Length Distribution (n={len(val_lengths)})')
cur_axes[val_plot_index].grid(True)
cur_axes[val_plot_index].set_axisbelow(True)
# Plot the third histogram
if not(test_plot_index is None):
cur_axes[test_plot_index].hist(test_lengths, bins=20, edgecolor='k',color=colormap['test'])
cur_axes[test_plot_index].set_xlabel(xlabel)
cur_axes[test_plot_index].set_ylabel(ylabel)
cur_axes[test_plot_index].set_title(f'Test Set Length Distribution (n={len(test_lengths)})')
cur_axes[test_plot_index].grid(True)
cur_axes[test_plot_index].set_axisbelow(True)
# Adjust layout
fig_individual.set_tight_layout(True)
# Save the figure
fig_individual.savefig(savepath)
log_update(f"\tSaved figure to {savepath}")
def visualize_splits_scatter(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, colormap=None, savepath='splits/scatterplot.png', axes=None):
    """
    Scatter plot of the cluster-size distribution for each split, highlighting benchmark-containing test clusters.
    """
    set_font()
    if colormap is None: colormap = default_color_map
    # Create a standalone figure with a single axes
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
# Set axes list
axes_list = [axes_individual] if axes is None else [axes_individual, axes]
log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")
# Make grouped versions of these DataFrames for size analysis
train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
if not(val_clusters is None):
val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
if not(test_clusters is None):
test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    # Isolate benchmark-containing clusters so their contribution can be plotted separately.
    # Benchmark clusters live in the test set, so they are only relevant when test clusters exist.
    if test_clusters is None: benchmark_cluster_reps = None
    total_test_proteins = sum(test_clustersgb['member count']) if not(test_clusters is None) else 0
if not(benchmark_cluster_reps is None):
test_clustersgb['benchmark cluster'] = test_clustersgb['representative seq_id'].isin(benchmark_cluster_reps)
benchmark_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster']].reset_index(drop=True)
        test_clustersgb = test_clustersgb.loc[~test_clustersgb['benchmark cluster']].reset_index(drop=True)
    # Convert them to value counts
    # NOTE: the 'index' -> 'cluster size (n_members)' rename assumes pandas < 2.0, where
    # Series.value_counts().reset_index() yields columns ['index', 'member count']
    train_clustersgb = train_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
if not(val_clusters is None):
val_clustersgb = val_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
if not(test_clusters is None):
test_clustersgb = test_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
if not(benchmark_cluster_reps is None):
benchmark_clustersgb = benchmark_clustersgb['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
# Get the percentage of each dataset that's made of each cluster size
train_clustersgb['n_proteins'] = train_clustersgb['cluster size (n_members)']*train_clustersgb['n_clusters'] # proteins per cluster * n clusters = # proteins
train_clustersgb['percent_proteins'] = train_clustersgb['n_proteins']/sum(train_clustersgb['n_proteins'])
if not(val_clusters is None):
val_clustersgb['n_proteins'] = val_clustersgb['cluster size (n_members)']*val_clustersgb['n_clusters']
val_clustersgb['percent_proteins'] = val_clustersgb['n_proteins']/sum(val_clustersgb['n_proteins'])
if not(test_clusters is None):
test_clustersgb['n_proteins'] = test_clustersgb['cluster size (n_members)']*test_clustersgb['n_clusters']
test_clustersgb['percent_proteins'] = test_clustersgb['n_proteins']/total_test_proteins
if not(benchmark_cluster_reps is None):
benchmark_clustersgb['n_proteins'] = benchmark_clustersgb['cluster size (n_members)']*benchmark_clustersgb['n_clusters']
benchmark_clustersgb['percent_proteins'] = benchmark_clustersgb['n_proteins']/total_test_proteins
# Specially mark the benchmark clusters because these can't be reallocated
for ax in axes_list:
ax.plot(train_clustersgb['cluster size (n_members)'],train_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['train'],label='train')
if not(val_clusters is None):
ax.plot(val_clustersgb['cluster size (n_members)'],val_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['val'],label='val')
if not(test_clusters is None):
ax.plot(test_clustersgb['cluster size (n_members)'],test_clustersgb['percent_proteins'],linestyle='None',marker='.',color=colormap['test'],label='test')
if not(benchmark_cluster_reps is None):
ax.plot(benchmark_clustersgb['cluster size (n_members)'],benchmark_clustersgb['percent_proteins'],
marker='o',
linestyle='None',
markerfacecolor=colormap['test'], # fill same as test
markeredgecolor='black', # outline black
markeredgewidth=1.5,
label='benchmark'
)
        ax.set(ylabel='Percentage of Proteins in Dataset', xlabel='Cluster Size (Number of Members)')
ax.legend()
# save the figure
fig_individual.set_tight_layout(True)
fig_individual.savefig(savepath)
log_update(f"\tSaved figure to {savepath}")
def visualize_splits_tsne(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, esm_type="esm2_t33_650M_UR50D", embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl", savepath='splits/tsne_plot.png', axes=None):
    """
    Generate a t-SNE plot of embeddings for train, test, and validation.
    """
    set_font()
    if colormap is None: colormap = default_color_map
log_update('\nMaking t-SNE plot of train, val, and test embeddings')
    # Create a standalone figure with a single axes
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
# Set axes list
axes_list = [axes_individual] if axes is None else [axes_individual, axes]
# Combine the embeddings into one array
train_embeddings, val_embeddings, test_embeddings = get_avg_embeddings_for_tsne(train_sequences=train_sequences,
val_sequences=val_sequences,
test_sequences=test_sequences, embedding_path=embedding_path)
    # get_avg_embeddings_for_tsne returns (possibly empty) lists, so decide which sets to
    # include based on emptiness rather than None
    if len(val_embeddings) > 0 and len(test_embeddings) > 0:
        embeddings = np.concatenate([train_embeddings, val_embeddings, test_embeddings])
        labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings) + ['test'] * len(test_embeddings)
    elif len(val_embeddings) > 0:
        embeddings = np.concatenate([train_embeddings, val_embeddings])
        labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings)
    elif len(test_embeddings) > 0:
        embeddings = np.concatenate([train_embeddings, test_embeddings])
        labels = ['train'] * len(train_embeddings) + ['test'] * len(test_embeddings)
    else:
        embeddings = np.array(train_embeddings)
        labels = ['train'] * len(train_embeddings)
# Perform t-SNE
tsne = TSNE(n_components=2, random_state=42)
tsne_results = tsne.fit_transform(embeddings)
# Convert t-SNE results into a DataFrame
tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
tsne_df['label'] = labels
for ax in axes_list:
# Scatter plot for each set
for label, color in colormap.items():
subset = tsne_df[tsne_df['label'] == label].reset_index(drop=True)
ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)
ax.set_title(f't-SNE of {esm_type} Embeddings')
ax.set_xlabel('t-SNE Dimension 1')
ax.set_ylabel('t-SNE Dimension 2')
ax.legend()
ax.grid(True)
    # Save the standalone figure
fig_individual.set_tight_layout(True)
fig_individual.savefig(savepath)
log_update(f"\tSaved figure to {savepath}")
def visualize_splits_shannon_entropy(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/shannon_entropy_plot.png', axes=None):
    """
    Generate Shannon entropy plots for train, validation, and test sets.
    """
    set_font()
# Get index for test plot
val_plot_index, test_plot_index, total_plots = 1, 2, 3
if val_sequences is None:
val_plot_index = None
test_plot_index-= 1
total_plots-=1
if test_sequences is None:
test_plot_index = None
total_plots-=1
if colormap is None: colormap=default_color_map
    # Create a standalone figure with one row of subplots (one per provided split)
    fig_individual, axes_individual = plt.subplots(1, total_plots, figsize=(6*total_plots, 6))
# Set axes list
axes_list = [axes_individual] if axes is None else [axes_individual, axes]
log_update('\nMaking histogram of Shannon Entropy distributions')
train_entropy = [calculate_shannon_entropy(seq) for seq in train_sequences]
if not(val_plot_index is None):
val_entropy = [calculate_shannon_entropy(seq) for seq in val_sequences]
if not(test_plot_index is None):
test_entropy = [calculate_shannon_entropy(seq) for seq in test_sequences]
for ax in axes_list:
ax[0].hist(train_entropy, bins=20, edgecolor='k', color=colormap['train'])
ax[0].set_title(f'Train Set (n={len(train_entropy)})')
ax[0].set_xlabel('Shannon Entropy')
ax[0].set_ylabel('Frequency')
ax[0].grid(True)
ax[0].set_axisbelow(True)
if not(val_plot_index is None):
ax[val_plot_index].hist(val_entropy, bins=20, edgecolor='k', color=colormap['val'])
ax[val_plot_index].set_title(f'Validation Set (n={len(val_entropy)})')
ax[val_plot_index].set_xlabel('Shannon Entropy')
ax[val_plot_index].grid(True)
ax[val_plot_index].set_axisbelow(True)
if not(test_plot_index is None):
ax[test_plot_index].hist(test_entropy, bins=20, edgecolor='k', color=colormap['test'])
ax[test_plot_index].set_title(f'Test Set (n={len(test_entropy)})')
ax[test_plot_index].set_xlabel('Shannon Entropy')
ax[test_plot_index].grid(True)
ax[test_plot_index].set_axisbelow(True)
fig_individual.set_tight_layout(True)
fig_individual.savefig(savepath)
log_update(f"\tSaved figure to {savepath}")
def visualize_splits_aa_composition(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/aa_comp.png', axes=None):
    """
    Generate a grouped bar plot of amino acid composition for the provided splits.
    """
    set_font()
    if colormap is None: colormap = default_color_map
    # Create a standalone figure with a single axes
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
# Set axes list
axes_list = [axes_individual] if axes is None else [axes_individual, axes]
log_update('\nMaking bar plot of AA composition across each set')
train_comp = calculate_aa_composition(train_sequences)
if not(val_sequences is None):
val_comp = calculate_aa_composition(val_sequences)
if not(test_sequences is None):
test_comp = calculate_aa_composition(test_sequences)
# Create DataFrame
if not(val_sequences is None) and not(test_sequences is None):
comp_df = pd.DataFrame([train_comp, val_comp, test_comp], index=['train', 'val', 'test']).T
if not(val_sequences is None) and (test_sequences is None):
comp_df = pd.DataFrame([train_comp, val_comp], index=['train', 'val']).T
if (val_sequences is None) and not(test_sequences is None):
comp_df = pd.DataFrame([train_comp, test_comp], index=['train', 'test']).T
colors = [colormap[col] for col in comp_df.columns]
# Plotting
for ax in axes_list:
comp_df.plot(kind='bar', color=colors, ax=ax)
ax.set_title('Amino Acid Composition Across Datasets')
ax.set_xlabel('Amino Acid')
ax.set_ylabel('Relative Frequency')
fig_individual.set_tight_layout(True)
fig_individual.savefig(savepath)
log_update(f"\tSaved figure to {savepath}")
### Outer methods for visualizing splits
def visualize_splits(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, train_color='#0072B2',val_color='#009E73',test_color='#E69F00',esm_embeddings_path=None, onehot_embeddings_path=None):
colormap = {
'train': train_color,
'val': val_color,
'test': test_color
}
    valid_entry = False
    # Dispatch to the appropriate combined plot, forwarding the embeddings path so the
    # t-SNE panel can be drawn when embeddings are available
    if not(train_clusters is None) and not(val_clusters is None) and not(test_clusters is None):
        visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path)
        valid_entry = True
    if not(train_clusters is None) and (val_clusters is None) and not(test_clusters is None):
        visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path)
        valid_entry = True
    if not(train_clusters is None) and not(val_clusters is None) and (test_clusters is None):
        visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path)
        valid_entry = True
    if not valid_entry: raise ValueError("Must pass train clusters and at least one of val or test clusters")
def visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
if colormap is None: colormap=default_color_map
# Add length column
train_clusters['member length'] = train_clusters['member seq'].str.len()
val_clusters['member length'] = val_clusters['member seq'].str.len()
test_clusters['member length'] = test_clusters['member seq'].str.len()
# Prepare lengths and seqs for plotting
train_lengths = train_clusters['member length'].tolist()
val_lengths = val_clusters['member length'].tolist()
test_lengths = test_clusters['member length'].tolist()
train_sequences = train_clusters['member seq'].tolist()
val_sequences = val_clusters['member seq'].tolist()
test_sequences = test_clusters['member seq'].tolist()
# Create a combined figure with 3 rows and 3 columns
set_font()
fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))
# Make the three visualization plots for saving TOGETHER
visualize_splits_hist(train_lengths=train_lengths,
val_lengths=val_lengths,
test_lengths=test_lengths,
colormap=colormap, axes=axs[0])
visualize_splits_shannon_entropy(train_sequences=train_sequences,
val_sequences=val_sequences,
test_sequences=test_sequences,
colormap=colormap,axes=axs[1])
visualize_splits_scatter(train_clusters=train_clusters,
val_clusters=val_clusters,
test_clusters=test_clusters,
benchmark_cluster_reps=benchmark_cluster_reps,
colormap=colormap, axes=axs[2, 0])
visualize_splits_aa_composition(train_sequences=train_sequences,
val_sequences=val_sequences,
test_sequences=test_sequences,
colormap=colormap, axes=axs[2, 1])
    if not(esm_embeddings_path is None) and os.path.exists(esm_embeddings_path):
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=val_sequences,
                              test_sequences=test_sequences,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[2, 2])
else:
# Leave the last subplot blank
axs[2, 2].axis('off')
plt.tight_layout()
fig_combined.savefig('splits/combined_plot.png')
log_update(f"\nSaved combined figure to splits/combined_plot.png")
def visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
if colormap is None: colormap=default_color_map
# Add length column
train_clusters['member length'] = train_clusters['member seq'].str.len()
test_clusters['member length'] = test_clusters['member seq'].str.len()
# Prepare lengths and seqs for plotting
train_lengths = train_clusters['member length'].tolist()
test_lengths = test_clusters['member length'].tolist()
train_sequences = train_clusters['member seq'].tolist()
test_sequences = test_clusters['member seq'].tolist()
# Create a combined figure with 4 rows and 2 columns if TSNE plot, 3 x 2 otherwise
if not(esm_embeddings_path is None) and os.path.exists(esm_embeddings_path):
set_font()
fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=None,
                              test_sequences=test_sequences,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[3, 0])
axs[-1,1].axis('off')
else:
set_font()
fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))
# Make the three visualization plots for saving TOGETHER
visualize_splits_hist(train_lengths=train_lengths,
val_lengths=None,
test_lengths=test_lengths,
colormap=colormap, axes=axs[0])
visualize_splits_shannon_entropy(train_sequences=train_sequences,
val_sequences=None,
test_sequences=test_sequences,
colormap=colormap,axes=axs[1])
visualize_splits_scatter(train_clusters=train_clusters,
val_clusters=None,
test_clusters=test_clusters,
benchmark_cluster_reps=benchmark_cluster_reps,
colormap=colormap, axes=axs[2, 0])
visualize_splits_aa_composition(train_sequences=train_sequences,
val_sequences=None,
test_sequences=test_sequences,
colormap=colormap, axes=axs[2, 1])
plt.tight_layout()
fig_combined.savefig('splits/combined_plot.png')
log_update(f"\nSaved combined figure to splits/combined_plot.png")
def visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
if colormap is None: colormap=default_color_map
# Add length column
train_clusters['member length'] = train_clusters['member seq'].str.len()
val_clusters['member length'] = val_clusters['member seq'].str.len()
# Prepare lengths and seqs for plotting
train_lengths = train_clusters['member length'].tolist()
val_lengths = val_clusters['member length'].tolist()
train_sequences = train_clusters['member seq'].tolist()
val_sequences = val_clusters['member seq'].tolist()
# Create a combined figure with 4 rows and 2 columns if TSNE plot, 3 x 2 otherwise
if not(esm_embeddings_path is None) and os.path.exists(esm_embeddings_path):
set_font()
fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=val_sequences,
                              test_sequences=None,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[3, 0])
axs[-1,1].axis('off')
else:
set_font()
fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))
# Make the three visualization plots for saving TOGETHER
visualize_splits_hist(train_lengths=train_lengths,
val_lengths=val_lengths,
test_lengths=None,
colormap=colormap, axes=axs[0])
visualize_splits_shannon_entropy(train_sequences=train_sequences,
val_sequences=val_sequences,
test_sequences=None,
colormap=colormap,axes=axs[1])
visualize_splits_scatter(train_clusters=train_clusters,
val_clusters=val_clusters,
test_clusters=None,
benchmark_cluster_reps=benchmark_cluster_reps,
colormap=colormap, axes=axs[2, 0])
visualize_splits_aa_composition(train_sequences=train_sequences,
val_sequences=val_sequences,
test_sequences=None,
colormap=colormap, axes=axs[2, 1])
plt.tight_layout()
fig_combined.savefig('splits/combined_plot.png')
log_update(f"\nSaved combined figure to splits/combined_plot.png")