import os
import pickle
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.manifold import TSNE

from fuson_plm.utils.logging import log_update
from fuson_plm.utils.visualizing import set_font

# Font size used for axis labels, titles, and tick labels throughout these figures.
_FONTSIZE = 24
# Number of bins in each per-split histogram panel.
_HIST_BINS = 20


def calculate_aa_composition(sequences):
    """
    Calculate the relative frequency of each character pooled across sequences.

    Args:
        sequences (iterable of str): Sequences to pool (e.g., amino-acid strings).
            May be a one-shot iterable; it is consumed exactly once.

    Returns:
        dict: Maps each character to (its count / total number of characters).
            Keys appear in order of first occurrence. Empty if there are no characters.
    """
    counts = Counter()
    for seq in sequences:
        counts.update(seq)
    total_length = sum(counts.values())
    return {aa: count / total_length for aa, count in counts.items()}


def calculate_shannon_entropy(sequence):
    """
    Calculate the Shannon entropy (in bits) of the symbol distribution of a sequence.

    Args:
        sequence (str): A sequence of characters (e.g., amino acids or nucleotides).

    Returns:
        float: Shannon entropy with log base 2.
    """
    symbol_counts = Counter(sequence)
    return entropy(list(symbol_counts.values()), base=2)


def _draw_split_histogram(ax, values, color, title, xlabel, ylabel=None):
    """Draw one histogram panel in the shared style used by the per-split histogram plots."""
    ax.hist(values, bins=_HIST_BINS, edgecolor='k', color=color)
    ax.set_title(title, fontsize=_FONTSIZE)
    ax.set_xlabel(xlabel, fontsize=_FONTSIZE)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=_FONTSIZE)
    ax.grid(True)
    ax.set_axisbelow(True)  # draw grid lines behind the bars
    ax.tick_params(axis='x', labelsize=_FONTSIZE)
    ax.tick_params(axis='y', labelsize=_FONTSIZE)


def _save_figure(fig, savepath, tight_layout=True):
    """Save `fig` (the figure that owns the plotted axes, not merely the current figure)."""
    if tight_layout:
        fig.tight_layout()
    fig.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")


def visualize_splits_hist(train_lengths, val_lengths, test_lengths, colormap,
                          savepath='../data/splits/length_distributions.png', axes=None):
    """
    Plot side-by-side histograms of sequence lengths for the train, val, and test sets.

    Args:
        train_lengths, val_lengths, test_lengths (list of int): Sequence lengths per split.
        colormap (dict): Maps 'train', 'val', 'test' to matplotlib colors.
        savepath (str or None): Where to save the figure; None skips saving.
        axes (sequence of 3 Axes or None): Axes to draw into; a new 1x3 figure is made if None.
    """
    log_update('\nMaking histogram of length distributions')

    if axes is None:
        _, axes = plt.subplots(1, 3, figsize=(18, 6))

    panels = [
        (train_lengths, 'train', 'Train Set Length Distribution'),
        (val_lengths, 'val', 'Validation Set Length Distribution'),
        (test_lengths, 'test', 'Test Set Length Distribution'),
    ]
    for ax, (lengths, split, title) in zip(axes, panels):
        _draw_split_histogram(
            ax, lengths, colormap[split],
            title=f'{title} (n={len(lengths)})',
            xlabel='Sequence Length (AA)',
            ylabel='Frequency',
        )

    if savepath is not None:
        _save_figure(axes[0].figure, savepath)


def _member_counts(clusters):
    """Return one row per cluster: 'representative seq_id' and its 'member count'."""
    return (
        clusters.groupby('representative seq_id')['member seq_id']
        .count()
        .reset_index()
        .rename(columns={'member seq_id': 'member count'})
    )


def _cluster_size_distribution(member_counts, total_proteins=None):
    """
    Summarize per-cluster member counts by cluster size.

    Args:
        member_counts (pd.Series): Number of members in each cluster.
        total_proteins (number or None): Denominator for 'percent_proteins'; defaults to the
            number of proteins represented by `member_counts` itself.

    Returns:
        pd.DataFrame with columns 'cluster size (n_members)', 'n_clusters', 'n_proteins',
        and 'percent_proteins' (a fraction in [0, 1], not multiplied by 100).
    """
    # Built explicitly rather than via value_counts().reset_index().rename(...), whose
    # column names changed in pandas 2.0 ('index'/<name> -> <name>/'count').
    size_counts = member_counts.value_counts()
    dist = pd.DataFrame({
        'cluster size (n_members)': size_counts.index.to_numpy(),
        'n_clusters': size_counts.to_numpy(),
    })
    # proteins per cluster * number of clusters of that size = number of proteins
    dist['n_proteins'] = dist['cluster size (n_members)'] * dist['n_clusters']
    if total_proteins is None:
        total_proteins = dist['n_proteins'].sum()
    dist['percent_proteins'] = dist['n_proteins'] / total_proteins
    return dist


def visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap,
                             savepath='../data/splits/scatterplot.png', ax=None):
    """
    Scatter the fraction of each split's proteins that falls in clusters of each size.

    Test clusters whose representative is in `benchmark_cluster_reps` are drawn as separate,
    black-outlined markers. Both test subsets are normalized by the total number of test proteins.

    Args:
        train_clusters, val_clusters, test_clusters (pd.DataFrame): Must contain
            'representative seq_id' and 'member seq_id' columns.
        benchmark_cluster_reps (list): Representative seq_ids of benchmark-containing clusters.
        colormap (dict): Maps 'train', 'val', 'test' to matplotlib colors.
        savepath (str or None): Where to save the figure; None skips saving.
        ax (Axes or None): Axes to draw into; a new figure is made if None.
    """
    log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")

    train_counts = _member_counts(train_clusters)
    val_counts = _member_counts(val_clusters)
    test_counts = _member_counts(test_clusters)

    # Isolate benchmark-containing clusters so their contribution can be plotted separately
    total_test_proteins = test_counts['member count'].sum()
    is_benchmark = test_counts['representative seq_id'].isin(benchmark_cluster_reps)

    train_dist = _cluster_size_distribution(train_counts['member count'])
    val_dist = _cluster_size_distribution(val_counts['member count'])
    test_dist = _cluster_size_distribution(
        test_counts.loc[~is_benchmark, 'member count'], total_test_proteins)
    benchmark_dist = _cluster_size_distribution(
        test_counts.loc[is_benchmark, 'member count'], total_test_proteins)

    if ax is None:
        _, ax = plt.subplots(figsize=(18, 6))

    for dist, split in ((train_dist, 'train'), (val_dist, 'val'), (test_dist, 'test')):
        ax.plot(dist['cluster size (n_members)'], dist['percent_proteins'],
                linestyle='None', marker='.', color=colormap[split], label=split)
    # Specially mark the benchmark clusters because these can't be reallocated
    ax.plot(benchmark_dist['cluster size (n_members)'], benchmark_dist['percent_proteins'],
            marker='o',
            linestyle='None',
            markerfacecolor=colormap['test'],  # fill same as test
            markeredgecolor='black',           # outline black
            markeredgewidth=1.5,
            label='benchmark')

    # NOTE(review): y values are fractions in [0, 1] although the label says "Percentage".
    ax.set_ylabel('Percentage of Proteins in Dataset', fontsize=_FONTSIZE)
    ax.set_xlabel('Cluster Size', fontsize=_FONTSIZE)
    ax.tick_params(axis='x', labelsize=_FONTSIZE)
    ax.tick_params(axis='y', labelsize=_FONTSIZE)
    ax.legend(fontsize=_FONTSIZE, markerscale=4)

    if savepath is not None:
        _save_figure(ax.figure, savepath)


def get_avg_embeddings_for_tsne(train_sequences, val_sequences, test_sequences,
                                embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
    """
    Load a pickled {sequence: embedding} dict and split its values by dataset membership.

    Args:
        train_sequences, val_sequences, test_sequences (iterable of str): Sequences per split.
        embedding_path (str): Path to the pickle file.

    Returns:
        tuple(list, list, list) of embeddings for train, val, and test (in the pickle's key
        order), or None if the file could not be opened or unpickled.
    """
    try:
        # NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
        with open(embedding_path, 'rb') as f:
            embeddings = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError) as err:
        log_update(f"could not open embeddings at {embedding_path}: {err}")
        return None

    # Sets make each membership test O(1) instead of scanning the sequence lists.
    train_set, val_set, test_set = set(train_sequences), set(val_sequences), set(test_sequences)
    train_embeddings = [v for k, v in embeddings.items() if k in train_set]
    val_embeddings = [v for k, v in embeddings.items() if k in val_set]
    test_embeddings = [v for k, v in embeddings.items() if k in test_set]
    return train_embeddings, val_embeddings, test_embeddings


def visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap, esm_type="esm2_t33_650M_UR50D",
                          embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl",
                          savepath='../data/splits/tsne_plot.png', ax=None):
    """
    Generate a 2-D t-SNE plot of embeddings for train, validation, and test.

    If the embeddings cannot be loaded, the plot is skipped (and `ax`, if given, is hidden).

    Args:
        train_sequences, val_sequences, test_sequences (list of str): Sequences per split.
        colormap (dict): Maps 'train', 'val', 'test' to matplotlib colors.
        esm_type (str): Model name shown in the title.
        embedding_path (str): Pickle file passed to get_avg_embeddings_for_tsne.
        savepath (str or None): Where to save the figure; falsy skips saving.
        ax (Axes or None): Axes to draw into; a new figure is made if None.
    """
    log_update('\nMaking t-SNE plot of train, val, and test embeddings')

    split_embeddings = get_avg_embeddings_for_tsne(train_sequences, val_sequences, test_sequences,
                                                   embedding_path=embedding_path)
    if split_embeddings is None:
        log_update('\tSkipping t-SNE plot because embeddings could not be loaded')
        if ax is not None:
            ax.axis('off')
        return
    train_embeddings, val_embeddings, test_embeddings = split_embeddings

    # Stack all embeddings into one (n_samples, dim) array, with matching labels
    embeddings = np.asarray(list(train_embeddings) + list(val_embeddings) + list(test_embeddings))
    labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings) + ['test'] * len(test_embeddings)

    tsne_results = TSNE(n_components=2, random_state=42).fit_transform(embeddings)
    tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
    tsne_df['label'] = labels

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 8))

    for label, color in colormap.items():
        subset = tsne_df[tsne_df['label'] == label]
        ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)

    ax.set_title(f't-SNE of {esm_type} Embeddings')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.legend(fontsize=_FONTSIZE, markerscale=2)
    ax.grid(True)

    # Save via the axes' own figure: `ax` may have been supplied by the caller.
    if savepath:
        _save_figure(ax.figure, savepath)


def visualize_splits_shannon_entropy(train_sequences, val_sequences, test_sequences, colormap,
                                     savepath='../data/splits/shannon_entropy_plot.png', axes=None):
    """
    Generate Shannon entropy histograms for train, validation, and test sets.

    Args:
        train_sequences, val_sequences, test_sequences (list of str): Sequences per split.
        colormap (dict): Maps 'train', 'val', 'test' to matplotlib colors.
        savepath (str or None): Where to save the figure; None skips saving.
        axes (sequence of 3 Axes or None): Axes to draw into; a new 1x3 figure is made if None.
    """
    log_update('\nMaking histogram of Shannon Entropy distributions')

    train_entropy = [calculate_shannon_entropy(seq) for seq in train_sequences]
    val_entropy = [calculate_shannon_entropy(seq) for seq in val_sequences]
    test_entropy = [calculate_shannon_entropy(seq) for seq in test_sequences]

    if axes is None:
        _, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Only the left-most panel carries a y-axis label.
    panels = [
        (train_entropy, 'train', 'Train Set', 'Frequency'),
        (val_entropy, 'val', 'Validation Set', None),
        (test_entropy, 'test', 'Test Set', None),
    ]
    for ax, (values, split, title, ylabel) in zip(axes, panels):
        _draw_split_histogram(
            ax, values, colormap[split],
            title=f'{title} (n={len(values)})',
            xlabel='Shannon Entropy',
            ylabel=ylabel,
        )

    if savepath is not None:
        _save_figure(axes[0].figure, savepath)


def visualize_splits_aa_composition(train_sequences, val_sequences, test_sequences, colormap,
                                    savepath='../data/splits/aa_comp.png', ax=None):
    """
    Plot a grouped bar chart of per-split character (amino-acid) relative frequencies.

    Args:
        train_sequences, val_sequences, test_sequences (list of str): Sequences per split.
        colormap (dict): Maps 'train', 'val', 'test' to matplotlib colors.
        savepath (str or None): Where to save the figure; None skips saving.
        ax (Axes or None): Axes to draw into; a new figure is made if None.
    """
    log_update('\nMaking bar plot of AA composition across each set')

    comps = [calculate_aa_composition(seqs) for seqs in (train_sequences, val_sequences, test_sequences)]
    # Rows = amino acids, columns = splits
    comp_df = pd.DataFrame(comps, index=['train', 'val', 'test']).T
    colors = [colormap[col] for col in comp_df.columns]

    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6))
    else:
        fig = ax.get_figure()

    comp_df.plot(kind='bar', color=colors, ax=ax)
    ax.set_title('Amino Acid Composition Across Datasets', fontsize=_FONTSIZE)
    ax.set_xlabel('Amino Acid', fontsize=_FONTSIZE)
    ax.set_ylabel('Relative Frequency', fontsize=_FONTSIZE)
    ax.tick_params(axis='x', labelsize=_FONTSIZE)
    ax.tick_params(axis='y', labelsize=_FONTSIZE)
    ax.legend(fontsize=16, markerscale=2)

    if savepath is not None:
        _save_figure(fig, savepath, tight_layout=False)


def visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps,
                     train_color='#0072B2', val_color='#009E73', test_color='#E69F00',
                     esm_embeddings_path=None, onehot_embeddings_path=None):
    """
    Produce a combined 3x3 summary figure of the splits, plus each plot saved separately.

    Side effect: adds a 'member length' column to each of the three input DataFrames.

    Args:
        train_clusters, val_clusters, test_clusters (pd.DataFrame): Must contain
            'representative seq_id', 'member seq_id', and 'member seq' columns.
        benchmark_cluster_reps (list): Representative seq_ids of benchmark-containing clusters.
        train_color, val_color, test_color (str): Colors for each split.
        esm_embeddings_path (str or None): Pickled embeddings; the t-SNE plot is drawn only
            if this path is given and exists.
        onehot_embeddings_path: Currently unused.
    """
    colormap = {
        'train': train_color,
        'val': val_color,
        'test': test_color,
    }

    # Add columns for plotting (kept on the caller's DataFrames, as before)
    for clusters in (train_clusters, val_clusters, test_clusters):
        clusters['member length'] = clusters['member seq'].str.len()

    # Prepare lengths and seqs for plotting
    train_lengths = train_clusters['member length'].tolist()
    val_lengths = val_clusters['member length'].tolist()
    test_lengths = test_clusters['member length'].tolist()

    train_sequences = train_clusters['member seq'].tolist()
    val_sequences = val_clusters['member seq'].tolist()
    test_sequences = test_clusters['member seq'].tolist()

    have_esm_embeddings = esm_embeddings_path is not None and os.path.exists(esm_embeddings_path)

    # Combined figure with 3 rows and 3 columns
    fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))
    visualize_splits_hist(train_lengths, val_lengths, test_lengths, colormap, savepath=None, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences, val_sequences, test_sequences, colormap,
                                     savepath=None, axes=axs[1])
    visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap,
                             savepath=None, ax=axs[2, 0])
    visualize_splits_aa_composition(train_sequences, val_sequences, test_sequences, colormap,
                                    savepath=None, ax=axs[2, 1])
    if have_esm_embeddings:
        # Pass the checked path through; otherwise the t-SNE helper would load its own default.
        visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap,
                              embedding_path=esm_embeddings_path, savepath=None, ax=axs[2, 2])
    else:
        # Leave the last subplot blank
        axs[2, 2].axis('off')
    _save_figure(fig_combined, '../data/splits/combined_plot.png')

    # Make the same visualization plots again, each saved separately with its default path
    visualize_splits_hist(train_lengths, val_lengths, test_lengths, colormap)
    visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap)
    visualize_splits_aa_composition(train_sequences, val_sequences, test_sequences, colormap)
    visualize_splits_shannon_entropy(train_sequences, val_sequences, test_sequences, colormap)
    if have_esm_embeddings:
        visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap,
                              embedding_path=esm_embeddings_path)


def main():
    """Load the cluster splits and FusOn-DB, find benchmark-containing clusters, and plot."""
    set_font()

    train_clusters = pd.read_csv('splits/train_cluster_split.csv')
    val_clusters = pd.read_csv('splits/val_cluster_split.csv')
    test_clusters = pd.read_csv('splits/test_cluster_split.csv')
    clusters = pd.concat([train_clusters, val_clusters, test_clusters])

    fuson_db = pd.read_csv('fuson_db.csv')
    # Get the sequence IDs of all clustered benchmark sequences.
    benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna(), 'seq_id']

    # Use benchmark_seq_ids to find which clusters contain benchmark sequences.
    benchmark_cluster_reps = (
        clusters.loc[clusters['member seq_id'].isin(benchmark_seq_ids), 'representative seq_id']
        .unique()
        .tolist()
    )

    visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps,
                     esm_embeddings_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl',
                     onehot_embeddings_path=None)


if __name__ == "__main__":
    main()