import os
import pickle
from collections import Counter

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.font_manager import FontProperties
from scipy.stats import entropy
from sklearn.manifold import TSNE

from fuson_plm.utils.logging import log_update, find_fuson_plm_directory

# Default colour for each split (Okabe-Ito colourblind-safe palette).
default_color_map = {
    'train': '#0072B2',
    'val': '#009E73',
    'test': '#E69F00',
}

# Font files already registered with matplotlib's font manager. addfont() appends
# a new entry every time it is called, so each file is registered only once.
_REGISTERED_FONT_PATHS = set()

# Column holding the cluster size in the cluster-size distribution tables.
_CLUSTER_SIZE_COL = 'cluster size (n_members)'


def set_font():
    """Register the bundled Ubuntu fonts with matplotlib and make them the default.

    Fonts are read from ``<fuson_plm dir>/ubuntu_font``. The regular, bold, italic
    and bold-italic variants are registered once per process. The regular face
    becomes the global ``font.family``. Mathtext is switched to a custom font set
    whose roman, italic and bold faces are the matching Ubuntu variants.
    """
    font_dir = os.path.join(find_fuson_plm_directory(), 'ubuntu_font')
    font_paths = {
        style: os.path.join(font_dir, f'Ubuntu-{style}.ttf')
        for style in ('Regular', 'Bold', 'Italic', 'BoldItalic')
    }

    for path in font_paths.values():
        if path not in _REGISTERED_FONT_PATHS:
            fm.fontManager.addfont(path)
            _REGISTERED_FONT_PATHS.add(path)

    regular_name = FontProperties(fname=font_paths['Regular']).get_name()
    italic_name = FontProperties(fname=font_paths['Italic']).get_name()
    bold_name = FontProperties(fname=font_paths['Bold']).get_name()

    # Set the font family globally to Ubuntu
    plt.rcParams['font.family'] = regular_name
    # Math text (e.g. in labels) also uses the loaded Ubuntu fonts
    plt.rcParams['mathtext.fontset'] = 'custom'
    plt.rcParams['mathtext.rm'] = regular_name
    plt.rcParams['mathtext.it'] = italic_name
    plt.rcParams['mathtext.bf'] = bold_name


def get_avg_embeddings_for_tsne(train_sequences=None, val_sequences=None, test_sequences=None,
                                embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
    """Look up pre-computed average embeddings for each split's sequences.

    Args:
        train_sequences, val_sequences, test_sequences: iterables of sequences.
            ``None`` is treated as empty.
        embedding_path: pickle file holding a mapping from sequence to embedding.

    Returns:
        ``(train_embeddings, val_embeddings, test_embeddings)``: three lists of
        embeddings. Each list follows the pickle's key order, not the order of
        the sequences passed in. Sequences absent from the pickle are skipped.
        Returns ``None``, after printing a message, if the pickle cannot be loaded.
    """
    try:
        # NOTE(review): pickle.load can execute arbitrary code, so embedding_path
        # must point to a trusted, locally generated file.
        with open(embedding_path, 'rb') as f:
            embeddings = pickle.load(f)
    except Exception as err:  # best-effort: callers treat None as "no embeddings"
        print(f"could not open embeddings: {err}")
        return None

    def select(sequences):
        # A set makes each membership test O(1) instead of a list scan.
        wanted = set(sequences) if sequences is not None else set()
        return [emb for seq, emb in embeddings.items() if seq in wanted]

    return select(train_sequences), select(val_sequences), select(test_sequences)


def calculate_aa_composition(sequences):
    """Return the relative frequency of each residue pooled across all *sequences*.

    Each count is divided by the total number of residues over every sequence, so
    the frequencies sum to 1. Keys appear in first-seen order. An empty input
    yields an empty dict.
    """
    counts = Counter()
    for seq in sequences:
        counts.update(seq)
    total_length = sum(counts.values())
    return {aa: n / total_length for aa, n in counts.items()}


def calculate_shannon_entropy(sequence):
    """
    Calculate the Shannon entropy (in bits) for a given sequence.

    Args:
        sequence (str): A sequence of characters (e.g., amino acids or nucleotides).

    Returns:
        float: Shannon entropy value.
    """
    return entropy(list(Counter(sequence).values()), base=2)


def _target_axes(individual_axes, extra_axes):
    """Return the axes to draw on: the standalone figure's, plus the caller's if given."""
    return [individual_axes] if extra_axes is None else [individual_axes, extra_axes]


def _save_and_close(fig, savepath):
    """Lay out *fig*, write it to *savepath*, release it and log where it went."""
    fig.tight_layout()
    fig.savefig(savepath)
    plt.close(fig)  # avoid accumulating open figures across many plotting calls
    log_update(f"\tSaved figure to {savepath}")


def _present_splits(train_values, val_values, test_values):
    """Return ``(split_key, display_name, values)`` for train plus each non-None optional split."""
    panels = [('train', 'Train', train_values)]
    if val_values is not None:
        panels.append(('val', 'Validation', val_values))
    if test_values is not None:
        panels.append(('test', 'Test', test_values))
    return panels


def _draw_split_histograms(axes_row, panels, colormap, xlabel, title_template, ylabel_on_every_panel):
    """Draw one 20-bin histogram per split panel onto consecutive axes of *axes_row*.

    *title_template* is formatted with ``name`` (display name) and ``n`` (number
    of values). The y label is written on every panel, or on the first only.
    """
    for i, (key, display_name, values) in enumerate(panels):
        ax = axes_row[i]
        ax.hist(values, bins=20, edgecolor='k', color=colormap[key])
        ax.set_xlabel(xlabel)
        if ylabel_on_every_panel or i == 0:
            ax.set_ylabel('Frequency')
        ax.set_title(title_template.format(name=display_name, n=len(values)))
        ax.grid(True)
        ax.set_axisbelow(True)


def visualize_splits_hist(train_lengths=None, val_lengths=None, test_lengths=None, colormap=None,
                          savepath='splits/length_distributions.png', axes=None):
    """
    Works to plot train, val, test; train, val; or train, test

    Draws one sequence-length histogram per split, in the order train, val, test,
    skipping val/test when None. The plot is saved to *savepath*. If *axes* (a
    row of Axes with one entry per plotted split) is given, the same panels are
    drawn there too.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map
    log_update('\nMaking histogram of length distributions')

    panels = _present_splits(train_lengths, val_lengths, test_lengths)
    # squeeze=False keeps a 2-D array even for one panel, so row 0 is always indexable
    fig_individual, axes_grid = plt.subplots(1, len(panels), figsize=(6 * len(panels), 6), squeeze=False)

    for axes_row in _target_axes(axes_grid[0], axes):
        _draw_split_histograms(axes_row, panels, colormap, 'Sequence Length (AA)',
                               '{name} Set Length Distribution (n={n})', ylabel_on_every_panel=True)

    _save_and_close(fig_individual, savepath)


def _count_members_per_cluster(clusters):
    """Return one row per 'representative seq_id' with its number of members in 'member count'."""
    return (
        clusters.groupby('representative seq_id')['member seq_id']
        .count()
        .reset_index()
        .rename(columns={'member seq_id': 'member count'})
    )


def _cluster_size_distribution(member_counts, total_proteins=None):
    """Summarise how many proteins fall in clusters of each size.

    Args:
        member_counts: Series with the member count of each cluster.
        total_proteins: denominator for ``percent_proteins``. Defaults to the
            number of proteins represented by *member_counts* itself.

    Returns:
        DataFrame with columns 'cluster size (n_members)', 'n_clusters',
        'n_proteins' and 'percent_proteins' (a fraction, not scaled by 100).
    """
    # Built explicitly instead of value_counts().reset_index().rename(...), whose
    # column names changed in pandas 2.0.
    size_counts = member_counts.value_counts()
    dist = pd.DataFrame({
        _CLUSTER_SIZE_COL: size_counts.index.to_numpy(),
        'n_clusters': size_counts.to_numpy(),
    })
    # proteins per cluster * n clusters = # proteins
    dist['n_proteins'] = dist[_CLUSTER_SIZE_COL] * dist['n_clusters']
    if total_proteins is None:
        total_proteins = dist['n_proteins'].sum()
    dist['percent_proteins'] = dist['n_proteins'] / total_proteins
    return dist


def visualize_splits_scatter(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None,
                             colormap=None, savepath='splits/scatterplot.png', axes=None):
    """Scatter the share of each split's proteins that lives in clusters of each size.

    Cluster DataFrames need 'representative seq_id' and 'member seq_id' columns.
    Optional: *val_clusters* and *test_clusters* may be None. When
    *benchmark_cluster_reps* is given and test data is present, test clusters
    whose representative is in it are plotted separately as 'benchmark'. Both
    the test and benchmark series use the whole test set as their denominator.
    The plot is saved to *savepath* and also drawn on *axes* if given.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
    log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")

    split_dists = [('train', _cluster_size_distribution(_count_members_per_cluster(train_clusters)['member count']))]
    if val_clusters is not None:
        split_dists.append(('val', _cluster_size_distribution(_count_members_per_cluster(val_clusters)['member count'])))

    benchmark_dist = None
    if test_clusters is not None:
        test_counts = _count_members_per_cluster(test_clusters)
        # Computed before the benchmark split so test + benchmark percentages sum to 1.
        total_test_proteins = test_counts['member count'].sum()
        if benchmark_cluster_reps is not None:
            # Isolate benchmark-containing clusters so their contribution can be plotted separately
            is_benchmark = test_counts['representative seq_id'].isin(benchmark_cluster_reps)
            benchmark_dist = _cluster_size_distribution(test_counts.loc[is_benchmark, 'member count'],
                                                        total_test_proteins)
            test_counts = test_counts.loc[~is_benchmark]
        split_dists.append(('test', _cluster_size_distribution(test_counts['member count'], total_test_proteins)))

    for ax in _target_axes(axes_individual, axes):
        for label, dist in split_dists:
            ax.plot(dist[_CLUSTER_SIZE_COL], dist['percent_proteins'],
                    linestyle='None', marker='.', color=colormap[label], label=label)
        # Specially mark the benchmark clusters because these can't be reallocated
        if benchmark_dist is not None:
            ax.plot(benchmark_dist[_CLUSTER_SIZE_COL], benchmark_dist['percent_proteins'],
                    marker='o',
                    linestyle='None',
                    markerfacecolor=colormap['test'],  # fill same as test
                    markeredgecolor='black',  # outline black
                    markeredgewidth=1.5,
                    label='benchmark')
        ax.set(ylabel='Percentage of Proteins in Dataset', xlabel='cluster_size')
        ax.legend()

    _save_and_close(fig_individual, savepath)


def visualize_splits_tsne(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None,
                          esm_type="esm2_t33_650M_UR50D",
                          embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl",
                          savepath='splits/tsne_plot.png', axes=None):
    """
    Generate a t-SNE plot of embeddings for train, test, and validation.

    Embeddings are looked up in *embedding_path* (see get_avg_embeddings_for_tsne).
    Splits with no embeddings are left out. If nothing can be plotted, the step
    is skipped and *axes* (if given) is hidden. Otherwise the plot is saved to
    *savepath* and also drawn on *axes* if given.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map
    log_update('\nMaking t-SNE plot of train, val, and test embeddings')

    loaded = get_avg_embeddings_for_tsne(train_sequences=train_sequences, val_sequences=val_sequences,
                                         test_sequences=test_sequences, embedding_path=embedding_path)
    # Empty splits are dropped. np.concatenate cannot mix an empty list (1-D)
    # with stacked embeddings (2-D), and empty splits contribute no points anyway.
    parts = []
    if loaded is not None:
        parts = [(label, emb) for label, emb in zip(('train', 'val', 'test'), loaded) if len(emb) > 0]
    if not parts:
        log_update('\tNo embeddings available; skipping t-SNE plot')
        if axes is not None:
            axes.axis('off')
        return

    # Combine the embeddings into one array
    embeddings = np.concatenate([emb for _, emb in parts])
    labels = [label for label, emb in parts for _ in range(len(emb))]

    # Perform t-SNE
    tsne_results = TSNE(n_components=2, random_state=42).fit_transform(embeddings)
    tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
    tsne_df['label'] = labels

    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
    for ax in _target_axes(axes_individual, axes):
        # Scatter plot for each set
        for label, color in colormap.items():
            subset = tsne_df[tsne_df['label'] == label]
            if subset.empty:
                continue  # no legend entry for a split that has no points
            ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)
        ax.set_title(f't-SNE of {esm_type} Embeddings')
        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        ax.legend()
        ax.grid(True)

    _save_and_close(fig_individual, savepath)


def visualize_splits_shannon_entropy(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None,
                                     savepath='splits/shannon_entropy_plot.png', axes=None):
    """
    Generate Shannon entropy plots for train, validation, and test sets.

    One histogram of per-sequence Shannon entropy is drawn per split, skipping
    val/test when None. The plot is saved to *savepath* and also drawn on *axes*
    (a row of Axes, one per plotted split) if given.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map

    sequence_panels = _present_splits(train_sequences, val_sequences, test_sequences)
    fig_individual, axes_grid = plt.subplots(1, len(sequence_panels), figsize=(6 * len(sequence_panels), 6),
                                             squeeze=False)
    log_update('\nMaking histogram of Shannon Entropy distributions')

    entropy_panels = [
        (key, name, [calculate_shannon_entropy(seq) for seq in seqs])
        for key, name, seqs in sequence_panels
    ]
    for axes_row in _target_axes(axes_grid[0], axes):
        _draw_split_histograms(axes_row, entropy_panels, colormap, 'Shannon Entropy',
                               '{name} Set (n={n})', ylabel_on_every_panel=False)

    _save_and_close(fig_individual, savepath)


def visualize_splits_aa_composition(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None,
                                    savepath='splits/aa_comp.png', axes=None):
    """Grouped bar chart of amino-acid relative frequency per split.

    Val/test are skipped when None. The plot is saved to *savepath* and also
    drawn on *axes* if given.
    """
    set_font()
    if colormap is None:
        colormap = default_color_map
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))
    log_update('\nMaking bar plot of AA composition across each set')

    compositions = {
        key: calculate_aa_composition(seqs)
        for key, _, seqs in _present_splits(train_sequences, val_sequences, test_sequences)
    }
    # One column per split, one row per amino acid
    comp_df = pd.DataFrame(list(compositions.values()), index=list(compositions)).T
    colors = [colormap[col] for col in comp_df.columns]

    for ax in _target_axes(axes_individual, axes):
        comp_df.plot(kind='bar', color=colors, ax=ax)
        ax.set_title('Amino Acid Composition Across Datasets')
        ax.set_xlabel('Amino Acid')
        ax.set_ylabel('Relative Frequency')

    _save_and_close(fig_individual, savepath)


### Outer methods for visualizing splits
def visualize_splits(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None,
                     train_color='#0072B2', val_color='#009E73', test_color='#E69F00', esm_embeddings_path=None,
                     onehot_embeddings_path=None):
    """Make every split plot plus a combined figure for train/val/test, train/test or train/val.

    *esm_embeddings_path* enables the t-SNE panel when it points to an existing
    file. *onehot_embeddings_path* is accepted for interface compatibility but
    is currently unused.

    Raises:
        ValueError: if train is missing or both val and test are missing.
    """
    colormap = {
        'train': train_color,
        'val': val_color,
        'test': test_color,
    }
    if train_clusters is None or (val_clusters is None and test_clusters is None):
        raise ValueError("Must pass train and at least one of val or test")

    common = dict(benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap,
                  esm_embeddings_path=esm_embeddings_path, onehot_embeddings_path=onehot_embeddings_path)
    if val_clusters is not None and test_clusters is not None:
        visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, **common)
    elif val_clusters is None:
        visualize_train_test_splits(train_clusters, test_clusters, **common)
    else:
        visualize_train_val_splits(train_clusters, val_clusters, **common)


def _lengths_and_sequences(clusters):
    """Add a 'member length' column to *clusters* (in place) and return (lengths, sequences) lists.

    Returns (None, None) when *clusters* is None.
    """
    if clusters is None:
        return None, None
    clusters['member length'] = clusters['member seq'].str.len()
    return clusters['member length'].tolist(), clusters['member seq'].tolist()


def _visualize_combined(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap,
                        esm_embeddings_path):
    """Draw all split plots individually and onto one combined figure at splits/combined_plot.png.

    Layouts:
    - train/val/test uses a 3x3 grid, with t-SNE in the bottom-right cell.
    - A two-split run uses a 3x2 grid, or a 4x2 grid with t-SNE at [3, 0] when
      embeddings are available.
    """
    if colormap is None:
        colormap = default_color_map

    train_lengths, train_sequences = _lengths_and_sequences(train_clusters)
    val_lengths, val_sequences = _lengths_and_sequences(val_clusters)
    test_lengths, test_sequences = _lengths_and_sequences(test_clusters)

    make_tsne = esm_embeddings_path is not None and os.path.exists(esm_embeddings_path)

    set_font()
    if val_clusters is not None and test_clusters is not None:
        fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))
        tsne_ax, unused_ax = axs[2, 2], None
    elif make_tsne:
        fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
        tsne_ax, unused_ax = axs[3, 0], axs[3, 1]
    else:
        fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))
        tsne_ax, unused_ax = None, None

    visualize_splits_hist(train_lengths=train_lengths, val_lengths=val_lengths, test_lengths=test_lengths,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences, val_sequences=val_sequences,
                                     test_sequences=test_sequences, colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters, val_clusters=val_clusters, test_clusters=test_clusters,
                             benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences, val_sequences=val_sequences,
                                    test_sequences=test_sequences, colormap=colormap, axes=axs[2, 1])
    if make_tsne:
        visualize_splits_tsne(train_sequences=train_sequences, val_sequences=val_sequences,
                              test_sequences=test_sequences, colormap=colormap,
                              embedding_path=esm_embeddings_path, axes=tsne_ax)
    elif tsne_ax is not None:
        tsne_ax.axis('off')  # leave the t-SNE cell blank
    if unused_ax is not None:
        unused_ax.axis('off')

    # Target the combined figure explicitly: plt.tight_layout() would act on
    # whichever figure is current (the last individual plot), not this one.
    fig_combined.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    plt.close(fig_combined)
    log_update("\nSaved combined figure to splits/combined_plot.png")


def visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=None,
                                    colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    """Plot train/val/test splits individually and as a 3x3 combined figure.

    Adds a 'member length' column to each input DataFrame in place.
    *onehot_embeddings_path* is currently unused.
    """
    _visualize_combined(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap,
                        esm_embeddings_path)


def visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None,
                                esm_embeddings_path=None, onehot_embeddings_path=None):
    """Plot train/test splits individually and as a combined figure (4x2 with t-SNE, else 3x2).

    Adds a 'member length' column to each input DataFrame in place.
    *onehot_embeddings_path* is currently unused.
    """
    _visualize_combined(train_clusters, None, test_clusters, benchmark_cluster_reps, colormap,
                        esm_embeddings_path)


def visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=None, colormap=None,
                               esm_embeddings_path=None, onehot_embeddings_path=None):
    """Plot train/val splits individually and as a combined figure (4x2 with t-SNE, else 3x2).

    Adds a 'member length' column to each input DataFrame in place.
    *onehot_embeddings_path* is currently unused.
    """
    _visualize_combined(train_clusters, val_clusters, None, benchmark_cluster_reps, colormap,
                        esm_embeddings_path)