|
import matplotlib.pyplot as plt |
|
import matplotlib.font_manager as fm |
|
from matplotlib.font_manager import FontProperties |
|
from scipy.stats import entropy |
|
from sklearn.manifold import TSNE |
|
import pickle |
|
import pandas as pd |
|
import os |
|
import numpy as np |
|
from fuson_plm.utils.logging import log_update, find_fuson_plm_directory |
|
|
|
def set_font():
    """Register the bundled Ubuntu fonts with matplotlib and make them the defaults.

    The four ``.ttf`` files are read from the ``ubuntu_font`` folder inside the
    directory returned by ``find_fuson_plm_directory()``. After registering them,
    the regular face becomes ``font.family`` and the regular/italic/bold faces
    are wired into matplotlib's custom mathtext font set.
    """
    font_dir = os.path.join(find_fuson_plm_directory(), 'ubuntu_font')

    # One font file per style.
    font_paths = {
        'regular': os.path.join(font_dir, 'Ubuntu-Regular.ttf'),
        'bold': os.path.join(font_dir, 'Ubuntu-Bold.ttf'),
        'italic': os.path.join(font_dir, 'Ubuntu-Italic.ttf'),
        'bold_italic': os.path.join(font_dir, 'Ubuntu-BoldItalic.ttf'),
    }

    # FontProperties objects are used below to look up each font's name.
    font_props = {
        style: FontProperties(fname=path) for style, path in font_paths.items()
    }

    # Make every face known to matplotlib's font manager.
    for path in font_paths.values():
        fm.fontManager.addfont(path)

    # Default text font.
    regular_name = font_props['regular'].get_name()
    plt.rcParams['font.family'] = regular_name

    # Math text: use the same fonts instead of matplotlib's built-in set.
    plt.rcParams['mathtext.fontset'] = 'custom'
    plt.rcParams['mathtext.rm'] = regular_name
    plt.rcParams['mathtext.it'] = font_props['italic'].get_name()
    plt.rcParams['mathtext.bf'] = font_props['bold'].get_name()
|
|
|
# Default colors for each split, used by the plotting functions in this module
# whenever no explicit ``colormap`` is passed.
# (A ``global`` statement is a no-op at module scope, so a plain assignment is used.)
default_color_map = {
    'train': '#0072B2',
    'val': '#009E73',
    'test': '#E69F00',
}
|
|
|
def get_avg_embeddings_for_tsne(train_sequences=None, val_sequences=None, test_sequences=None, embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
    """Load pickled embeddings and pick out the ones belonging to each split.

    Args:
        train_sequences: Iterable of keys (sequences) in the train split.
            ``None`` is treated as empty.
        val_sequences: Same, for the validation split.
        test_sequences: Same, for the test split.
        embedding_path (str): Path to a pickle file containing a dict that maps
            each key to its embedding.

    Returns:
        ``(train_embeddings, val_embeddings, test_embeddings)``: three lists
        holding the dict values whose keys appear in the respective split, in
        the iteration order of the pickled dict. A split with no matches gives
        an empty list. If the file cannot be opened or unpickled, a message is
        printed and ``None`` is returned instead of a tuple.
    """
    # Sets give O(1) membership tests; the dict keys are hashable by definition.
    train_keys = set(train_sequences) if train_sequences is not None else set()
    val_keys = set(val_sequences) if val_sequences is not None else set()
    test_keys = set(test_sequences) if test_sequences is not None else set()

    # SECURITY NOTE: pickle.load can execute arbitrary code while loading.
    # Only point ``embedding_path`` at files from a trusted source.
    try:
        with open(embedding_path, 'rb') as f:
            embeddings = pickle.load(f)
    except (OSError, pickle.UnpicklingError, AttributeError, EOFError, ImportError, IndexError):
        # Best-effort loader: report the failure and let the caller decide.
        print("could not open embeddings")
        return None

    train_embeddings = [v for k, v in embeddings.items() if k in train_keys]
    val_embeddings = [v for k, v in embeddings.items() if k in val_keys]
    test_embeddings = [v for k, v in embeddings.items() if k in test_keys]

    return train_embeddings, val_embeddings, test_embeddings
|
|
|
|
|
def calculate_aa_composition(sequences):
    """Compute the relative frequency of every character across all sequences.

    Args:
        sequences: Iterable of sequences (e.g. strings of amino acids). It is
            traversed exactly once, so one-shot iterables such as generators
            are supported.

    Returns:
        dict: Maps each character to ``count / total number of characters``.
        Keys appear in order of first occurrence. Empty input gives ``{}``.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    counts = Counter()
    for seq in sequences:
        counts.update(seq)

    # The total length of all sequences equals the sum of per-character counts.
    total_length = sum(counts.values())

    return {aa: count / total_length for aa, count in counts.items()}
|
|
|
def calculate_shannon_entropy(sequence):
    """
    Calculate the Shannon entropy (in bits) for a given sequence.

    Args:
        sequence (str): A sequence of characters (e.g., amino acids or nucleotides).

    Returns:
        float: Shannon entropy value, computed by ``scipy.stats.entropy`` with
        ``base=2`` from the per-character counts.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    # A single pass counts every symbol. Counter keeps first-occurrence order,
    # so the counts are passed to scipy in a reproducible order (iterating a
    # set of strings is not order-stable across interpreter runs).
    counts = list(Counter(sequence).values())
    return entropy(counts, base=2)
|
|
|
def visualize_splits_hist(train_lengths=None, val_lengths=None, test_lengths=None, colormap=None, savepath='splits/length_distributions.png', axes=None):
    """
    Plot one sequence-length histogram per split and save the figure.

    Works to plot train, val, test; train, val; train, test; or train alone.

    Args:
        train_lengths: Sequence lengths of the train split (required).
        val_lengths: Sequence lengths of the validation split, or None to omit it.
        test_lengths: Sequence lengths of the test split, or None to omit it.
        colormap (dict): Maps 'train'/'val'/'test' to a color. Defaults to
            ``default_color_map``.
        savepath (str): Where the stand-alone figure is written.
        axes: Optional indexable collection of Axes (one per plotted split) on
            which the same histograms are drawn as well, e.g. a row of a larger
            subplot grid. A single Axes is accepted when only one split is plotted.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map

    log_update('\nMaking histogram of length distributions')

    # (colormap key, title prefix, data) for every split that was provided.
    panels = [('train', 'Train', train_lengths)]
    if val_lengths is not None:
        panels.append(('val', 'Validation', val_lengths))
    if test_lengths is not None:
        panels.append(('test', 'Test', test_lengths))
    total_plots = len(panels)

    # squeeze=False always returns a 2-D array of Axes, so indexing also works
    # when only a single panel is requested.
    fig_individual, axes_individual = plt.subplots(
        1, total_plots, figsize=(6 * total_plots, 6), squeeze=False
    )

    # Draw on the stand-alone figure and, if given, on the caller's axes too.
    axes_list = [axes_individual[0]]
    if axes is not None:
        axes_list.append(np.atleast_1d(axes))

    xlabel, ylabel = 'Sequence Length (AA)', 'Frequency'

    for cur_axes in axes_list:
        for idx, (split, title, lengths) in enumerate(panels):
            ax = cur_axes[idx]
            ax.hist(lengths, bins=20, edgecolor='k', color=colormap[split])
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(f'{title} Set Length Distribution (n={len(lengths)})')
            ax.grid(True)
            ax.set_axisbelow(True)  # keep grid lines behind the bars

    fig_individual.set_tight_layout(True)

    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig_individual)
|
|
|
def _members_per_cluster(clusters):
    """Return one row per cluster representative with its 'member count'."""
    return (
        clusters.groupby('representative seq_id')['member seq_id']
        .count()
        .reset_index()
        .rename(columns={'member seq_id': 'member count'})
    )


def _cluster_size_distribution(member_counts, total_proteins=None):
    """Tabulate, per cluster size, the number of clusters, proteins and protein fraction.

    ``total_proteins`` is the denominator of 'percent_proteins'; when omitted,
    the number of proteins covered by ``member_counts`` itself is used.
    """
    # rename_axis + reset_index(name=...) names both columns explicitly, so the
    # result does not depend on the default column names of value_counts(),
    # which changed in pandas 2.0.
    dist = (
        member_counts.value_counts()
        .rename_axis('cluster size (n_members)')
        .reset_index(name='n_clusters')
    )
    dist['n_proteins'] = dist['cluster size (n_members)'] * dist['n_clusters']
    if total_proteins is None:
        total_proteins = dist['n_proteins'].sum()
    dist['percent_proteins'] = dist['n_proteins'] / total_proteins
    return dist


def visualize_splits_scatter(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, colormap=None, savepath='splits/scatterplot.png', axes=None):
    """Scatter the fraction of proteins found at each cluster size, per split, and save it.

    Args:
        train_clusters (pd.DataFrame): Train clusters with columns
            'representative seq_id' and 'member seq_id' (required).
        val_clusters (pd.DataFrame): Validation clusters (same columns), or None.
        test_clusters (pd.DataFrame): Test clusters (same columns), or None.
        benchmark_cluster_reps: Representative IDs of test clusters to plot as a
            separate 'benchmark' series. Only used when ``test_clusters`` is
            given. Test and benchmark fractions share one denominator: the total
            number of test proteins.
        colormap (dict): Maps 'train'/'val'/'test' to a color. Defaults to
            ``default_color_map``.
        savepath (str): Where the stand-alone figure is written.
        axes: Optional extra Axes on which the same plot is drawn.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map

    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Draw on the stand-alone figure and, if given, on the caller's axes too.
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")

    # Each entry: (distribution DataFrame, keyword arguments for ax.plot).
    layers = []

    train_counts = _members_per_cluster(train_clusters)['member count']
    layers.append((
        _cluster_size_distribution(train_counts),
        dict(linestyle='None', marker='.', color=colormap['train'], label='train'),
    ))

    if val_clusters is not None:
        val_counts = _members_per_cluster(val_clusters)['member count']
        layers.append((
            _cluster_size_distribution(val_counts),
            dict(linestyle='None', marker='.', color=colormap['val'], label='val'),
        ))

    if test_clusters is not None:
        test_members = _members_per_cluster(test_clusters)
        # Denominator covers ALL test proteins, benchmark clusters included.
        total_test_proteins = test_members['member count'].sum()

        benchmark_counts = None
        test_counts = test_members['member count']
        if benchmark_cluster_reps is not None:
            # Split benchmark clusters off so they are plotted as their own series.
            is_benchmark = test_members['representative seq_id'].isin(benchmark_cluster_reps)
            benchmark_counts = test_members.loc[is_benchmark, 'member count']
            test_counts = test_members.loc[~is_benchmark, 'member count']

        layers.append((
            _cluster_size_distribution(test_counts, total_test_proteins),
            dict(linestyle='None', marker='.', color=colormap['test'], label='test'),
        ))
        if benchmark_counts is not None:
            layers.append((
                _cluster_size_distribution(benchmark_counts, total_test_proteins),
                dict(
                    marker='o',
                    linestyle='None',
                    markerfacecolor=colormap['test'],
                    markeredgecolor='black',
                    markeredgewidth=1.5,
                    label='benchmark',
                ),
            ))

    for ax in axes_list:
        for dist, plot_kwargs in layers:
            ax.plot(dist['cluster size (n_members)'], dist['percent_proteins'], **plot_kwargs)
        ax.set(ylabel='Percentage of Proteins in Dataset', xlabel='cluster_size')
        ax.legend()

    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig_individual)
|
|
|
|
|
def visualize_splits_tsne(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, esm_type="esm2_t33_650M_UR50D", embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl", savepath='splits/tsne_plot.png',axes=None):
    """
    Generate a t-SNE plot of embeddings for train, test, and validation.

    Args:
        train_sequences: Sequences in the train split.
        val_sequences: Sequences in the validation split, or None to omit it.
        test_sequences: Sequences in the test split, or None to omit it.
        colormap (dict): Maps 'train'/'val'/'test' to a color. Defaults to
            ``default_color_map``.
        esm_type (str): Model name; only used in the plot title.
        embedding_path (str): Pickle file passed to ``get_avg_embeddings_for_tsne``.
        savepath (str): Where the stand-alone figure is written.
        axes: Optional extra Axes on which the same plot is drawn.

    If the embeddings cannot be loaded, or none match the given sequences, a
    message is logged and nothing is plotted or saved.
    """
    set_font()

    if colormap is None:
        colormap = default_color_map

    log_update('\nMaking t-SNE plot of train, val, and test embeddings')

    loaded = get_avg_embeddings_for_tsne(train_sequences=train_sequences,
                                         val_sequences=val_sequences,
                                         test_sequences=test_sequences,
                                         embedding_path=embedding_path)
    # The loader returns None (not a tuple) when the file could not be read.
    if loaded is None:
        log_update(f"\tSkipping t-SNE plot: could not load embeddings from {embedding_path}")
        return

    # The loader returns an EMPTY LIST (never None) for a split that was not
    # passed or has no matches. Empty lists must be left out: concatenating
    # them with 2-D embedding blocks raises a dimension-mismatch error.
    groups = [
        (label, split_embeddings)
        for label, split_embeddings in zip(('train', 'val', 'test'), loaded)
        if len(split_embeddings) > 0
    ]
    if not groups:
        log_update("\tSkipping t-SNE plot: no embeddings found for the given sequences")
        return

    embeddings = np.concatenate([split_embeddings for _, split_embeddings in groups])
    labels = []
    for label, split_embeddings in groups:
        labels.extend([label] * len(split_embeddings))

    # Fixed random_state so repeated runs give the same projection.
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
    tsne_df['label'] = labels

    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Draw on the stand-alone figure and, if given, on the caller's axes too.
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    for ax in axes_list:
        for label, color in colormap.items():
            subset = tsne_df[tsne_df['label'] == label].reset_index(drop=True)
            if subset.empty:
                # Do not add a legend entry for a split that has no points.
                continue
            ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)

        ax.set_title(f't-SNE of {esm_type} Embeddings')
        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        ax.legend()
        ax.grid(True)

    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig_individual)
|
|
|
def visualize_splits_shannon_entropy(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/shannon_entropy_plot.png',axes=None):
    """
    Generate Shannon entropy plots for train, validation, and test sets.

    One histogram is drawn per split that was provided (train is required;
    val and/or test may be None), and the figure is saved to ``savepath``.

    Args:
        train_sequences: Sequences in the train split (required).
        val_sequences: Sequences in the validation split, or None to omit it.
        test_sequences: Sequences in the test split, or None to omit it.
        colormap (dict): Maps 'train'/'val'/'test' to a color. Defaults to
            ``default_color_map``.
        savepath (str): Where the stand-alone figure is written.
        axes: Optional indexable collection of Axes (one per plotted split) on
            which the same histograms are drawn as well. A single Axes is
            accepted when only one split is plotted.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map

    log_update('\nMaking histogram of Shannon Entropy distributions')

    # (colormap key, title prefix, per-sequence entropies) for every provided split.
    panels = [('train', 'Train', [calculate_shannon_entropy(seq) for seq in train_sequences])]
    if val_sequences is not None:
        panels.append(('val', 'Validation', [calculate_shannon_entropy(seq) for seq in val_sequences]))
    if test_sequences is not None:
        panels.append(('test', 'Test', [calculate_shannon_entropy(seq) for seq in test_sequences]))
    total_plots = len(panels)

    # squeeze=False always returns a 2-D array of Axes, so indexing also works
    # when only a single panel is requested.
    fig_individual, axes_individual = plt.subplots(
        1, total_plots, figsize=(6 * total_plots, 6), squeeze=False
    )

    # Draw on the stand-alone figure and, if given, on the caller's axes too.
    axes_list = [axes_individual[0]]
    if axes is not None:
        axes_list.append(np.atleast_1d(axes))

    for cur_axes in axes_list:
        for idx, (split, title, entropies) in enumerate(panels):
            ax = cur_axes[idx]
            ax.hist(entropies, bins=20, edgecolor='k', color=colormap[split])
            ax.set_title(f'{title} Set (n={len(entropies)})')
            ax.set_xlabel('Shannon Entropy')
            if idx == 0:
                # Only the first (train) panel carries the y-axis label.
                ax.set_ylabel('Frequency')
            ax.grid(True)
            ax.set_axisbelow(True)  # keep grid lines behind the bars

    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig_individual)
|
|
|
def visualize_splits_aa_composition(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/aa_comp.png',axes=None):
    """Draw a grouped bar plot of amino-acid composition per split and save it.

    Args:
        train_sequences: Sequences in the train split (required).
        val_sequences: Sequences in the validation split, or None to omit it.
        test_sequences: Sequences in the test split, or None to omit it.
        colormap (dict): Maps 'train'/'val'/'test' to a color. Defaults to
            ``default_color_map``.
        savepath (str): Where the stand-alone figure is written.
        axes: Optional extra Axes on which the same plot is drawn.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map

    log_update('\nMaking bar plot of AA composition across each set')

    # Composition of every split that was provided; train alone is fine too.
    compositions = {'train': calculate_aa_composition(train_sequences)}
    if val_sequences is not None:
        compositions['val'] = calculate_aa_composition(val_sequences)
    if test_sequences is not None:
        compositions['test'] = calculate_aa_composition(test_sequences)

    # Rows = amino acids, columns = splits (after the transpose).
    comp_df = pd.DataFrame(list(compositions.values()), index=list(compositions)).T
    colors = [colormap[col] for col in comp_df.columns]

    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Draw on the stand-alone figure and, if given, on the caller's axes too.
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    for ax in axes_list:
        comp_df.plot(kind='bar', color=colors, ax=ax)
        ax.set_title('Amino Acid Composition Across Datasets')
        ax.set_xlabel('Amino Acid')
        ax.set_ylabel('Relative Frequency')

    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig_individual)
|
|
|
|
|
def visualize_splits(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, train_color='#0072B2',val_color='#009E73',test_color='#E69F00',esm_embeddings_path=None, onehot_embeddings_path=None):
    """Dispatch to the combined-plot function matching the splits that were passed.

    Args:
        train_clusters: Train clusters (required).
        val_clusters: Validation clusters, or None.
        test_clusters: Test clusters, or None.
        benchmark_cluster_reps: Forwarded unchanged to the selected function.
        train_color, val_color, test_color: Colors used to build the colormap
            that is forwarded to the selected function.
        esm_embeddings_path: Forwarded unchanged to the selected function.
        onehot_embeddings_path: Forwarded unchanged to the selected function.

    Raises:
        Exception: If train is missing, or neither val nor test is given.
    """
    colormap = {
        'train': train_color,
        'val': val_color,
        'test': test_color,
    }

    if train_clusters is None or (val_clusters is None and test_clusters is None):
        raise Exception("Must pass train and at least one of val or test")

    # Options shared by all three combined-plot functions. The embedding paths
    # are forwarded so they are not silently dropped.
    shared_kwargs = dict(
        benchmark_cluster_reps=benchmark_cluster_reps,
        colormap=colormap,
        esm_embeddings_path=esm_embeddings_path,
        onehot_embeddings_path=onehot_embeddings_path,
    )

    if val_clusters is not None and test_clusters is not None:
        visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, **shared_kwargs)
    elif val_clusters is None:
        visualize_train_test_splits(train_clusters, test_clusters, **shared_kwargs)
    else:
        visualize_train_val_splits(train_clusters, val_clusters, **shared_kwargs)
|
|
|
def visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    """Build the 3x3 combined figure for train/val/test and save it to splits/combined_plot.png.

    Row 0 holds length histograms, row 1 Shannon-entropy histograms, and row 2
    the cluster-size scatter, amino-acid composition and (optionally) t-SNE.
    Each helper also saves its own stand-alone figure.

    Args:
        train_clusters, val_clusters, test_clusters (pd.DataFrame): Cluster
            tables with a 'member seq' column (plus the columns needed by
            ``visualize_splits_scatter``). The inputs are not modified.
        benchmark_cluster_reps: Forwarded to ``visualize_splits_scatter``.
        colormap (dict): Maps 'train'/'val'/'test' to a color. Defaults to
            ``default_color_map``.
        esm_embeddings_path: If set and the file exists, a t-SNE panel is drawn
            from it; otherwise that panel is switched off.
        onehot_embeddings_path: Accepted but not used by this function.
    """
    if colormap is None:
        colormap = default_color_map

    # Lengths are computed without adding a column to the caller's DataFrames.
    train_lengths = train_clusters['member seq'].str.len().tolist()
    val_lengths = val_clusters['member seq'].str.len().tolist()
    test_lengths = test_clusters['member seq'].str.len().tolist()
    train_sequences = train_clusters['member seq'].tolist()
    val_sequences = val_clusters['member seq'].tolist()
    test_sequences = test_clusters['member seq'].tolist()

    set_font()
    fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))

    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=val_lengths,
                          test_lengths=test_lengths,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=val_sequences,
                                     test_sequences=test_sequences,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=val_clusters,
                             test_clusters=test_clusters,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=val_sequences,
                                    test_sequences=test_sequences,
                                    colormap=colormap, axes=axs[2, 1])
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        # Use the file that was just checked, not the helper's default path.
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=val_sequences,
                              test_sequences=test_sequences,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[2, 2])
    else:
        # No embeddings available: leave the last panel blank.
        axs[2, 2].axis('off')

    # Lay out the combined figure explicitly; plt.tight_layout() would act on
    # whichever figure is "current", which need not be this one once the
    # helpers have created figures of their own.
    fig_combined.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")

    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig_combined)
|
|
|
def visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    """Build the combined figure for train/test only and save it to splits/combined_plot.png.

    The grid has two columns. Row 0 holds length histograms, row 1
    Shannon-entropy histograms, row 2 the cluster-size scatter and amino-acid
    composition. A fourth row with a t-SNE panel is added when embeddings are
    available. Each helper also saves its own stand-alone figure.

    Args:
        train_clusters, test_clusters (pd.DataFrame): Cluster tables with a
            'member seq' column (plus the columns needed by
            ``visualize_splits_scatter``). The inputs are not modified.
        benchmark_cluster_reps: Forwarded to ``visualize_splits_scatter``.
        colormap (dict): Maps 'train'/'val'/'test' to a color. Defaults to
            ``default_color_map``.
        esm_embeddings_path: If set and the file exists, a t-SNE row is drawn
            from it.
        onehot_embeddings_path: Accepted but not used by this function.
    """
    if colormap is None:
        colormap = default_color_map

    # Lengths are computed without adding a column to the caller's DataFrames.
    train_lengths = train_clusters['member seq'].str.len().tolist()
    test_lengths = test_clusters['member seq'].str.len().tolist()
    train_sequences = train_clusters['member seq'].tolist()
    test_sequences = test_clusters['member seq'].tolist()

    set_font()
    make_tsne = esm_embeddings_path is not None and os.path.exists(esm_embeddings_path)
    if make_tsne:
        fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
    else:
        fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))

    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=None,
                          test_lengths=test_lengths,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=None,
                                     test_sequences=test_sequences,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=None,
                             test_clusters=test_clusters,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=None,
                                    test_sequences=test_sequences,
                                    colormap=colormap, axes=axs[2, 1])
    if make_tsne:
        # Use the file that was just checked, not the helper's default path.
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=None,
                              test_sequences=test_sequences,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[3, 0])
        # The t-SNE row only needs one panel; blank out its neighbour.
        axs[-1, 1].axis('off')

    # Lay out the combined figure explicitly; plt.tight_layout() would act on
    # whichever figure is "current", which need not be this one once the
    # helpers have created figures of their own.
    fig_combined.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")

    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig_combined)
|
|
|
def visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    """Build the combined figure for train/val only and save it to splits/combined_plot.png.

    The grid has two columns. Row 0 holds length histograms, row 1
    Shannon-entropy histograms, row 2 the cluster-size scatter and amino-acid
    composition. A fourth row with a t-SNE panel is added when embeddings are
    available. Each helper also saves its own stand-alone figure.

    Args:
        train_clusters, val_clusters (pd.DataFrame): Cluster tables with a
            'member seq' column (plus the columns needed by
            ``visualize_splits_scatter``). The inputs are not modified.
        benchmark_cluster_reps: Forwarded to ``visualize_splits_scatter``.
        colormap (dict): Maps 'train'/'val'/'test' to a color. Defaults to
            ``default_color_map``.
        esm_embeddings_path: If set and the file exists, a t-SNE row is drawn
            from it.
        onehot_embeddings_path: Accepted but not used by this function.
    """
    if colormap is None:
        colormap = default_color_map

    # Lengths are computed without adding a column to the caller's DataFrames.
    train_lengths = train_clusters['member seq'].str.len().tolist()
    val_lengths = val_clusters['member seq'].str.len().tolist()
    train_sequences = train_clusters['member seq'].tolist()
    val_sequences = val_clusters['member seq'].tolist()

    set_font()
    make_tsne = esm_embeddings_path is not None and os.path.exists(esm_embeddings_path)
    if make_tsne:
        fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
    else:
        fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))

    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=val_lengths,
                          test_lengths=None,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=val_sequences,
                                     test_sequences=None,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=val_clusters,
                             test_clusters=None,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=val_sequences,
                                    test_sequences=None,
                                    colormap=colormap, axes=axs[2, 1])
    if make_tsne:
        # Use the file that was just checked, not the helper's default path.
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=val_sequences,
                              test_sequences=None,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[3, 0])
        # The t-SNE row only needs one panel; blank out its neighbour.
        axs[-1, 1].axis('off')

    # Lay out the combined figure explicitly; plt.tight_layout() would act on
    # whichever figure is "current", which need not be this one once the
    # helpers have created figures of their own.
    fig_combined.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")

    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig_combined)