scMKL_analysis / plotting.py
ivango17's picture
fixed JASPAR umaps
38f0a6e
#!/usr/bin/env python
import pandas as pd
import numpy as np
import plotly.express as px
def get_best_alpha(stats_df, modality):
    '''
    Return the alpha with the highest mean AUROC for scMKL on one modality.

    stats_df: a DataFrame of results with (at least) the columns
        'Model', 'Modality', 'Alpha' and 'AUROC'
    modality: the value of the 'Modality' column to find the best alpha for

    Returns the best alpha for the modality. When several alphas tie for the
    best mean AUROC, the smallest one is returned.

    Raises ValueError if stats_df has no scMKL rows for the modality.
    '''
    is_target = (stats_df['Model'] == 'scMKL') & (stats_df['Modality'] == modality)
    # groupby sorts the Alpha keys, so idxmax picks the smallest alpha on ties
    mean_auroc = stats_df.loc[is_target].groupby('Alpha')['AUROC'].mean()
    if mean_auroc.empty:
        raise ValueError(f"No scMKL results found for modality {modality!r}")
    return mean_auroc.idxmax()
def format_datatype_grouping(dtype_grouping):
    '''
    Lower-case known grouping names so they match the labels used in the DataFrames.

    dtype_grouping: a str, or a list/tuple of str (e.g. 'RNA - Hallmark')
    Only the words 'Hallmark', 'Cistrome', 'Motifs' and 'Neuronal' are
    lower-cased; all other text is left unchanged.

    Returns a list of formatted names for list/tuple input, otherwise a str.
    '''
    capitalized_words = ('Hallmark', 'Cistrome', 'Motifs', 'Neuronal')

    def _format(name):
        # Apply the same replacements, in the same order, to a single name
        for word in capitalized_words:
            name = name.replace(word, word.lower())
        return name

    if isinstance(dtype_grouping, (list, tuple)):
        return [_format(selection) for selection in dtype_grouping]
    return _format(dtype_grouping)
def performance_boxplot(stats_df: pd.DataFrame, dataset: str, modality, metric: str, x_flag = "intersect", x_var = 'Alpha', color_dict = None):
    '''
    Plot a metric of scMKL runs for one dataset as a plotly box plot.

    stats_df: a DataFrame of results with (at least) the columns 'Dataset',
        'Modality', 'Model', x_var and metric; 'Alpha Star' is needed when
        x_flag == 'best' and 'Number_of_Selected_Groups' when
        x_var == 'Mean_Number_of_Selected_Groups'
    dataset: MCF7, T47D, lymphoma, prostate
    modality: a modality name, or a list/tuple of modality names, to visualize
    metric: the column shown on the y axis
    x_flag: "intersect" keeps only x_var values present for every requested
        modality (list/tuple input only); "best" keeps only rows with
        'Alpha Star' == 'Yes' and, when x_var == 'Alpha', plots the
        per-modality mean of those alphas instead
    x_var: the column shown on the x axis (replaced by 'Modality' when metric
        is 'RAM_usage' or 'Inference_time')
    color_dict: optional mapping from modality to color

    Returns a plotly Figure.
    '''
    # Formatting modality names to match the DataFrame labels
    modality = format_datatype_grouping(modality)
    multiple_modalities = isinstance(modality, (list, tuple))
    # Always iterate over a list so a single modality string is not split into characters
    modality_list = list(modality) if multiple_modalities else [modality]

    # Filtering to the desired dataset and modality(s); copy so the column
    # assignments below never write into a view of the caller's DataFrame
    stats_df = stats_df[(stats_df['Dataset'] == dataset)
                        & (stats_df['Modality'].isin(modality_list))
                        & (stats_df['Model'] == 'scMKL')].copy()

    if multiple_modalities and x_flag == "intersect":
        # Keep only x values that were run for every requested modality
        shared_x = set(np.unique(stats_df[x_var]))
        for mod in modality_list:
            shared_x &= set(np.unique(stats_df.loc[stats_df['Modality'] == mod, x_var]))
        stats_df = stats_df[stats_df[x_var].isin(shared_x)].copy()

    if x_flag == 'best':
        stats_df = stats_df[stats_df['Alpha Star'] == 'Yes'].copy()
        # Summarize each modality by the mean of its selected (best) alphas
        stats_df['Mean Alpha Star'] = (
            stats_df.groupby('Modality', observed = True)['Alpha'].transform('mean').round(3))
        x_var = 'Mean Alpha Star' if x_var == 'Alpha' else x_var

    if x_var == 'Mean_Number_of_Selected_Groups':
        # Broadcast each modality's mean number of selected groups onto its rows
        stats_df['Mean_Number_of_Selected_Groups'] = (
            stats_df.groupby('Modality', observed = True)['Number_of_Selected_Groups'].transform('mean'))

    # Making x_var categorical for plotting
    if metric in ('RAM_usage', 'Inference_time'):
        x_var = 'Modality'
    else:
        stats_df = stats_df.sort_values(by = x_var)
        categories = np.unique(stats_df[x_var])
        if 'Alpha' in x_var:
            # Alpha-like axes are ordered from largest to smallest
            categories = categories[::-1]
        stats_df[x_var] = pd.Categorical(stats_df[x_var], categories = categories)

    if x_var != 'Modality':
        x_values = np.unique(stats_df[x_var])
        # Box width in data units, 2% of the spread of the x values; None (auto) when empty
        width_x = (max(x_values) - min(x_values)) * 0.02 if len(x_values) else None
    else:
        width_x = None

    performance_bp = px.box(
        data_frame = stats_df,
        x = x_var,
        y = metric,
        color = 'Modality',
        template = 'plotly_white',
        height = 800,
        hover_name = 'Modality',
        category_orders = {'Modality' : modality_list},
        color_discrete_map = color_dict
    ).update_traces(width = width_x,
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel = dict(
            bgcolor = "white",
            font_size = 16,
            namelength = 40),
        font = dict(
            size = 20
        )
    ).update_xaxes(autorange = 'reversed' if x_var == 'Alpha' else None)
    return performance_bp
def comparison_boxplot(stats_df: pd.DataFrame, dataset: str, model, metric: str):
    '''
    Make a box plot of one metric to compare models on a dataset.

    stats_df: a DataFrame with (at least) the columns 'Dataset', 'Model',
        'Modality', 'Alpha Star' and metric
    dataset: the dataset to plot
    model: a model name or list of model names to include
        (colors are defined for 'scMKL', 'XGBoost' and 'MLP')
    metric: the column shown on the y axis

    scMKL rows are restricted to those with 'Alpha Star' == 'Yes'.
    Returns a plotly Figure with one box per "Model (Modality)" pair, ordered
    by increasing mean metric.
    '''
    # Filtering dataframe to desired dataset
    stats_df = stats_df[stats_df['Dataset'] == dataset]
    # Modalities included in the comparison
    subset_modalities = ['RNA - hallmark', 'ATAC - hallmark', 'ATAC_TFIDF - hallmark',
                         'RNA - all', 'ATAC - mvf', 'GENE_SCORES - hallmark']
    # Removing gene scores for lymphoma (TODO: make this less ad hoc)
    if dataset == "lymphoma":
        stats_df = stats_df[~stats_df['Modality'].isin(['GENE_SCORES - hallmark', 'GENE_SCORES - all'])]
    # Filtering dataframe to desired models
    stats_df = stats_df[np.isin(stats_df['Model'], model)]
    # Filtering scMKL runs to best runs; other models are kept as-is
    if 'scMKL' in model:
        stats_df = stats_df[(stats_df['Alpha Star'] == 'Yes') | (stats_df['Model'] != 'scMKL')]
    # Copy so the new column below is not written into a view of the input
    stats_df = stats_df[stats_df['Modality'].isin(subset_modalities)].copy()
    stats_df['Model (Modality)'] = stats_df['Model'] + " (" + stats_df['Modality'] + ")"

    # Order boxes from lowest to highest mean metric
    group_order = stats_df.groupby('Model (Modality)')[metric].mean().sort_values().index.tolist()
    stats_df['Model (Modality)'] = pd.Categorical(stats_df['Model (Modality)'], categories = group_order)

    models_bp = px.box(
        data_frame = stats_df,
        x = 'Model (Modality)',
        y = metric,
        color = 'Model',
        template = 'plotly_white',
        height = 700,
        category_orders = {'Model' : ['scMKL', 'XGBoost', 'MLP'],
                           'Model (Modality)' : group_order},
        color_discrete_map = {
            'scMKL' : 'red',
            'XGBoost' : 'blue',
            'MLP' : 'green'
        }
    ).update_traces(width = 0.75,
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel = dict(
            bgcolor = "white",
            font_size = 16,
            namelength = 40),
        font = dict(
            size = 20
        )
    )
    return models_bp
def plot_umap(umap_dict, modality, dataset, grouping, label, subset):
    '''
    Build a 3D scatter plot of precomputed UMAP embeddings.

    umap_dict: nested dict indexed as umap_dict[modality][dataset][feature_key];
        each entry holds 'Embeddings' (columns 0, 1, 2 become UMAP_1..UMAP_3)
        and 'Cell Labels' (indexed by label)
    grouping, subset: used to build feature_key
        - subset "None"      -> "Most Variable Features"
        - grouping 'Hallmark' -> 'hallmark_HALLMARK_<SUBSET_UPPER_SNAKE>'
        - grouping 'JASPAR'   -> 'motifs_<subset>'
        - otherwise           -> '<grouping lower>_<subset with spaces as _>'
    label: the cell label used to color the points

    Returns a plotly 3D scatter figure.
    '''
    # Resolve the feature-set key; the "None" subset takes precedence over grouping
    if subset == "None":
        feature_key = "Most Variable Features"
    elif grouping == 'Hallmark':
        feature_key = f"{grouping.lower()}_HALLMARK_{subset.replace(' ', '_').upper()}"
    elif grouping == 'JASPAR':
        feature_key = f"motifs_{subset}"
    else:
        feature_key = f"{grouping.lower()}_{subset.replace(' ', '_')}"

    entry = umap_dict[modality][dataset][feature_key]
    embedding_df = pd.DataFrame(entry['Embeddings']).rename(
        columns = {0 : "UMAP_1", 1 : "UMAP_2", 2 : "UMAP_3"})
    embedding_df[label] = np.array(entry["Cell Labels"][label])

    figure = px.scatter_3d(
        data_frame = embedding_df,
        x = 'UMAP_1',
        y = 'UMAP_2',
        z = 'UMAP_3',
        color = label,
        template = 'plotly_white',
        height = 650,
    )
    figure.update_layout(hoverlabel = dict(font_size = 16, namelength = 40))
    figure.update_traces(marker = dict(size = 3))
    return figure
def weights_boxplot(norm_df: pd.DataFrame, dataset, modality, shown_groups = 9):
    '''
    Plot the distribution of normalized group weights per alpha, one facet per group.

    norm_df: a DataFrame with (at least) the columns 'Dataset', 'Modality',
        'Group', 'Alpha', 'Norm' and 'Proportion Selected'
    dataset: the dataset to plot
    modality: a single modality name (formatted with format_datatype_grouping)
    shown_groups: either a number or a list-like of group names to display
        - if a number, the groups with the largest total 'Proportion Selected'
          are shown, most frequently selected first
        - if list-like, those groups are shown in the given order

    Returns a plotly Figure.
    '''
    modality = format_datatype_grouping(modality)
    norm_df = norm_df[(norm_df['Dataset'] == dataset) & (norm_df['Modality'] == modality)]

    if isinstance(shown_groups, (int, np.integer)):
        # Total selection proportion per group, ascending
        selection_totals = (norm_df.groupby('Group', observed = False)['Proportion Selected']
                            .sum().sort_values())
        # tail() copes with shown_groups exceeding the number of groups
        top_groups = np.array(selection_totals.tail(shown_groups).index)
        # Most frequently selected group first
        group_order = list(top_groups[::-1])
    else:
        # Previously this branch left top_groups undefined (NameError below)
        group_order = list(shown_groups)

    # Copy so the Alpha conversion below is not written into a view of the input
    norm_df = norm_df[norm_df['Group'].isin(group_order)].copy()

    # Order alphas numerically (largest first) before turning them into labels;
    # sorting the string labels would misorder values such as '1e-05'
    alpha_order = pd.Series(np.unique(norm_df['Alpha'])[::-1]).astype(str).tolist()
    norm_df['Alpha'] = norm_df['Alpha'].astype(str)

    norm_plot = px.box(
        data_frame = norm_df,
        x = 'Alpha',
        y = 'Norm',
        color = 'Group',
        template = 'plotly_white',
        height = 700,
        facet_col = 'Group',
        facet_col_wrap = 3,
        category_orders = {'Group' : group_order,
                           'Alpha' : alpha_order}
    ).update_traces(
        width = 0.75,
    ).update_yaxes(title = ''
    ).update_xaxes(title = ''
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel = dict(
            bgcolor = "white",
            font_size = 16,
            namelength = 40),
        font = dict(
            size = 20
        ),
        showlegend = False,
        # Axis titles placed on specific facets; assumes the default 3-column facet grid
        yaxis4 = dict(title = "Normalized Weight"),
        xaxis2 = dict(title = "Alpha")
    ).for_each_annotation(lambda a: a.update(text = a.text.split("=")[-1].replace("_", " "))
    )
    return norm_plot
def plot_features(selections_df, dataset, modality, n_features = 40):
    '''
    Plot the most frequently selected features of one experiment as a bar plot.

    selections_df: a DataFrame with (at least) the columns 'Dataset', 'Modality',
        'feature', 'selection' and 'Number of Groups Feature in'
    dataset: the dataset to plot
    modality: the modality name (formatted with format_datatype_grouping)
    n_features: how many of the most selected features to show (default 40)

    NOTE: if 'motif' is in the formatted modality, returns None
    Returns a plotly object as a horizontal bar plot of the top selected features.
    '''
    modality = format_datatype_grouping(modality)
    if 'motif' in modality:
        return None

    # Formatting DataFrame
    selections_df = selections_df[(selections_df['Dataset'] == dataset) & (selections_df['Modality'] == modality)]
    # Ascending order so the most selected features come last
    selections_df = selections_df.sort_values(by = 'selection', ascending = True).copy()
    selections_df['feature'] = pd.Categorical(selections_df['feature'], categories = selections_df['feature'])
    # tail() keeps every row when fewer than n_features exist; the old
    # iloc[len - 40:len] started at a negative index and dropped rows
    selections_df = selections_df.tail(n_features)

    gf_bar = px.bar(
        data_frame = selections_df,
        orientation = 'h',
        x = 'selection',
        y = 'feature',
        template = 'plotly_white',
        color = 'Number of Groups Feature in',
        height = 700,
        color_continuous_scale = px.colors.sequential.Bluered,
    ).update_layout(
        xaxis = dict(title = 'Times Selected by scMKL'),
        yaxis = dict(title = 'Features'),
        font = dict(size = 12),
    )
    return gf_bar
def create_volcano(vol_df, dataset, modality, grouping, group, grouping_dict):
    '''
    Plot -log10(adjusted p-value) against log fold change for one dataset.

    vol_df: a processed DataFrame with 'Dataset', '-log10(adjusted p-val)' and
        'Enrichment' columns; RNA results use 'logfoldchanges', 'names' and
        'pvals_adj', ATAC results use 'log2(fold_change)', 'feature name' and
        'adjusted p-value'
    dataset: the dataset name; "song_prostate" and "prostate" are mapped to
        'prostate_rna' and 'prostate_atac'
    modality: "RNA" selects the RNA columns; any other value is treated as ATAC
    grouping, group: unless group == "None", points are restricted to the
        features in grouping_dict[dataset]['RNA' | 'ATAC'][grouping][group]
    grouping_dict: nested dict of feature sets

    Returns a plotly scatter figure.
    '''
    if dataset == "song_prostate":
        dataset = 'prostate_rna'
    elif dataset == 'prostate':
        dataset = 'prostate_atac'

    reg_colors = {'Up-regulated' : 'green',
                  'Down-regulated' : 'red',
                  'Not significant' : 'blue'}

    # Copy so the Enrichment relabelling below is not written into a view of the input
    vol_df = vol_df[vol_df['Dataset'] == dataset].copy()

    if modality == "RNA":
        lfc = "logfoldchanges"
        label_name = 'names'
        adj_pval = 'pvals_adj'
        if group != "None":
            # The HALLMARK_ prefix is applied regardless of grouping for RNA
            group = "HALLMARK_" + group.replace(" ", "_").upper()
            vol_df = vol_df[vol_df['names'].isin(list(grouping_dict[dataset]['RNA'][grouping][group]))]
    else:
        lfc = "log2(fold_change)"
        label_name = 'feature name'
        adj_pval = 'adjusted p-value'
        # Collapse enrichment labels onto the keys of reg_colors ('Up' wins if both appear)
        vol_df['Enrichment'] = vol_df['Enrichment'].apply(
            lambda x: 'Up-regulated' if 'Up' in x else ('Down-regulated' if 'Down' in x else x))
        if group != "None":
            if grouping == "Hallmark":
                group = "HALLMARK_" + group.upper().replace(" ", "_")
            vol_df = vol_df[vol_df['feature name'].isin(list(grouping_dict[dataset]['ATAC'][grouping][group]))]

    vol_plot = px.scatter(
        data_frame = vol_df,
        x = lfc,
        y = '-log10(adjusted p-val)',
        color = 'Enrichment',
        template = 'plotly_white',
        hover_name = label_name,
        # hover_data is documented as list-like; a bare str is not portable across plotly versions
        hover_data = [adj_pval],
        color_discrete_map = reg_colors,
        height = 650,
    ).update_layout(
        hoverlabel = dict(
            font_size = 16,
            namelength = 40),
        font = dict(
            size = 20
        )
    )
    return vol_plot
def gene_distribution(freq_df):
    '''
    Plot how many gene sets each gene belongs to, as a log-scaled histogram.

    freq_df: a DataFrame with a 'Number of Sets' column (one row per gene)
    Returns a plotly histogram figure.
    '''
    figure = px.histogram(
        data_frame = freq_df,
        x = 'Number of Sets',
        template = 'plotly_white',
        color_discrete_sequence = ['blue'],
        log_y = True,
        title = "Distribution of Hallmark Gene Overlap",
    )
    figure.update_layout(font = dict(size = 16),
                         yaxis = dict(title = "log(Counts)"))
    return figure
def GO_plot(GO_df, dataset):
    '''
    Build a horizontal bar plot of GO biological process enrichment for one dataset.

    GO_df: a DataFrame with (at least) the columns 'Dataset', 'Group Name'
        and 'GSE (-log10(adj. p-val))'
    dataset: the dataset to plot

    Only the 30 most enriched gene sets are kept, ordered by decreasing
    enrichment; any " (..." suffix is stripped from the gene set names.
    Returns a plotly bar plot figure.
    '''
    score_column = 'GSE (-log10(adj. p-val))'
    top_sets = (
        GO_df[GO_df['Dataset'] == dataset]
        .sort_values(by = score_column, ascending = False)
        .iloc[:30]
        .reset_index()
        .rename(columns = {"Group Name" : "Gene Sets"})
    )
    # Keep only the text before the first " (" in each gene set name
    top_sets['Gene Sets'] = top_sets['Gene Sets'].apply(lambda name: name.partition(" (")[0])

    figure = px.bar(
        data_frame = top_sets,
        x = score_column,
        y = 'Gene Sets',
        color_discrete_sequence = ['pink'],
        template = 'plotly_white',
        category_orders = {'Gene Sets' : top_sets['Gene Sets']},
        height = 700,
    )
    figure.update_layout(yaxis = dict(dtick = 1), font = dict(size = 16))
    return figure
def hallmark_genesets_plot(hallmark_df, dataset):
    '''
    Plot hallmark gene set statistics side by side, one facet per statistic.

    hallmark_df: a long-format DataFrame with (at least) the columns 'Dataset',
        'Group', 'Variable' and 'Value'; 'Variable' is expected to take the
        values 'Proportion of DE Features',
        'Gene Set Enrichment (-log10(adjusted p-value))' and
        'scMKL Selection Frequency'
    dataset: the dataset to plot

    Gene sets are ordered by decreasing 'Proportion of DE Features'.
    Returns a plotly bar plot object.
    '''
    hallmark_df = hallmark_df[hallmark_df['Dataset'] == dataset]
    # Gene set order shared by every facet
    order = (hallmark_df[hallmark_df['Variable'] == 'Proportion of DE Features']
             .sort_values(by = 'Value', ascending = False)['Group']
             .tolist())

    hallmark_plot = px.bar(
        data_frame = hallmark_df,
        orientation = 'h',
        x = 'Value',
        y = 'Group',
        facet_col = 'Variable',
        color = 'Variable',
        template = 'plotly_white',
        height = 900,
        category_orders = {'Variable' : ['Proportion of DE Features', 'Gene Set Enrichment (-log10(adjusted p-value))', 'scMKL Selection Frequency'],
                           'Group' : order},
        hover_name = 'Group',
        color_discrete_sequence = ['blue', "orange", "red"]
    ).update_layout(
        yaxis = dict(title = 'Gene Sets', dtick = 1),
        font = dict(size = 16),
        xaxis1 = dict(title = "Proportion of DEG Overlap with Hallmark Gene Sets"),
        xaxis2 = dict(title = "-log10(adjusted p-value)"),
        xaxis3 = dict(title = "Times selected by scMKL"),
        showlegend = False
    # Each facet has its own value scale
    ).update_xaxes(matches = None
    ).for_each_annotation(lambda a: a.update(text = a.text.split("=")[-1].replace("_", " "))
    )
    return hallmark_plot