Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
def get_best_alpha(stats_df, modality):
    '''
    Takes a DataFrame of scMKL results and returns the alpha with the best mean AUROC.

    stats_df: a DataFrame with at least the columns 'Model', 'Modality',
        'Alpha' and 'AUROC'
    modality: the modality to find the best alpha for

    Returns the 'Alpha' value whose mean AUROC over the scMKL rows of the
    given modality is highest. Ties resolve to the smallest such alpha
    (groupby sorts its keys and idxmax returns the first maximum).

    Raises ValueError if there are no scMKL rows for the modality.
    '''
    mask = (stats_df['Model'] == 'scMKL') & (stats_df['Modality'] == modality)
    # Mean AUROC per alpha for this modality's scMKL runs
    mean_auroc = stats_df.loc[mask].groupby('Alpha')['AUROC'].mean()
    if mean_auroc.empty:
        raise ValueError(f"No scMKL results found for modality {modality!r}")
    return mean_auroc.idxmax()
def format_datatype_grouping(dtype_grouping):
    '''
    Lower-cases grouping names inside modality labels so they match the labels
    used in the results DataFrames (e.g. "ATAC - Hallmark" -> "ATAC - hallmark").

    dtype_grouping: a str, or a list/tuple of str
    Only the words Hallmark, Cistrome, Motifs and Neuronal are lower-cased.
    Returns a list of formatted names for list/tuple input, otherwise a str.
    '''
    def _format(name):
        # None of the replacements produces another pattern, so order is irrelevant
        for word in ('Hallmark', 'Cistrome', 'Motifs', 'Neuronal'):
            name = name.replace(word, word.lower())
        return name

    if isinstance(dtype_grouping, (list, tuple)):
        return [_format(selection) for selection in dtype_grouping]
    return _format(dtype_grouping)
def performance_boxplot(stats_df: pd.DataFrame, dataset: str, modality, metric: str, x_flag = "intersect", x_var = 'Alpha', color_dict = None):
    '''
    Plots a given metric of the scMKL runs for a given dataset as box plots.

    stats_df: a DataFrame with at least the columns 'Dataset', 'Modality',
        'Model', x_var and metric ('Alpha Star' is also needed when
        x_flag == 'best', and 'Number_of_Selected_Groups' when
        x_var == 'Mean_Number_of_Selected_Groups')
    dataset: MCF7, T47D, lymphoma, prostate
    modality: which modality (str) or modalities (list | tuple) should be visualized
    metric: which metric (column of stats_df) should be displayed
    x_flag: "intersect" -> with a list/tuple of modalities, keep only x_var
        values present for every modality; 'best' -> keep only rows with
        'Alpha Star' == 'Yes' and, for x_var 'Alpha', plot each modality's
        mean optimal alpha instead
    x_var: column placed on the x-axis (forced to 'Modality' when metric is
        'RAM_usage' or 'Inference_time')
    color_dict: optional {modality: color} mapping passed to plotly

    Returns a plotly Figure.
    '''
    # Formatting modality names; keep a list form too, because iterating a
    # bare str would walk its characters instead of whole modality names
    modality = format_datatype_grouping(modality)
    modality_list = [modality] if isinstance(modality, str) else list(modality)

    # Filtering data frame to desired dataset and modality(s). The copy keeps
    # the column assignments below off the caller's DataFrame.
    stats_df = stats_df[(stats_df['Dataset'] == dataset)
                        & np.isin(stats_df['Modality'], modality_list)
                        & (stats_df['Model'] == 'scMKL')].copy()

    if not isinstance(modality, str) and x_flag == "intersect":
        # Keep only x values that every requested modality has results for
        shared_x = np.unique(stats_df[x_var])
        for mod in modality_list:
            mod_x = np.unique(stats_df.loc[stats_df['Modality'] == mod, x_var])
            shared_x = shared_x[np.isin(shared_x, mod_x)]
        stats_df = stats_df[np.isin(stats_df[x_var], shared_x)].copy()

    if x_flag == 'best':
        stats_df = stats_df[stats_df['Alpha Star'] == 'Yes'].copy()
        # Label each row with its modality's mean optimal alpha (3 decimals)
        modality_alpha_means = {mod: round(np.mean(stats_df.loc[stats_df['Modality'] == mod, 'Alpha']), 3)
                                for mod in np.unique(stats_df['Modality'])}
        stats_df['Mean Alpha Star'] = stats_df['Modality'].apply(lambda mod: modality_alpha_means[mod])
        x_var = 'Mean Alpha Star' if x_var == 'Alpha' else x_var

    if x_var == 'Mean_Number_of_Selected_Groups':
        for mod in modality_list:
            in_mod = stats_df['Modality'] == mod
            stats_df.loc[in_mod, 'Mean_Number_of_Selected_Groups'] = np.mean(stats_df.loc[in_mod, 'Number_of_Selected_Groups'])

    # Making x_var categorical for plotting
    if metric in ('RAM_usage', 'Inference_time'):
        x_var = 'Modality'
    else:
        stats_df = stats_df.sort_values(by = x_var)
        x_levels = np.unique(stats_df[x_var])
        if 'Alpha' in x_var:
            # Alpha-type axes list their categories from largest to smallest
            x_levels = x_levels[::-1]
        stats_df[x_var] = pd.Categorical(stats_df[x_var], categories = x_levels)

    # Box width in x-axis units: 2% of the x range. None (the Modality axis, or
    # no data left after filtering) leaves the width to plotly.
    width_x = None
    if x_var != 'Modality':
        x_values = np.unique(stats_df[x_var])
        if x_values.size:
            width_x = (max(x_values) - min(x_values)) * 0.02

    performance_bp = px.box(
        data_frame = stats_df,
        x = x_var,
        y = metric,
        color = 'Modality',
        template = 'plotly_white',
        height = 800,
        hover_name = 'Modality',
        category_orders = {'Modality' : modality_list},
        color_discrete_map = color_dict
    ).update_traces(width = width_x
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel = dict(
            bgcolor = "white",
            font_size = 16,
            namelength = 40),
        font = dict(size = 20)
    ).update_xaxes(autorange = 'reversed' if x_var == 'Alpha' else None)
    return performance_bp
def comparison_boxplot(stats_df: pd.DataFrame, dataset: str, model, metric: str):
    '''
    Makes a box plot of the selected metric to compare models on one dataset.

    stats_df: a DataFrame with at least the columns 'Dataset', 'Model',
        'Modality', 'Alpha Star' and metric
    dataset: dataset to plot
    model: list-like of model names to include (e.g. 'scMKL', 'XGBoost', 'MLP')
    metric: column of stats_df shown on the y-axis

    Only a fixed subset of modalities is compared, and for scMKL only the
    'Alpha Star' == 'Yes' runs are kept. Boxes are ordered from lowest to
    highest mean metric.
    Returns a plotly Figure of the different model performances.
    '''
    # Modalities kept in the comparison
    subset_modalities = ['RNA - hallmark', 'ATAC - hallmark', 'ATAC_TFIDF - hallmark',
                         'RNA - all', 'ATAC - mvf', 'GENE_SCORES - hallmark']

    # Filtering dataframe to desired dataset
    stats_df = stats_df[stats_df['Dataset'] == dataset]
    # Gene scores are excluded for lymphoma
    if dataset == "lymphoma":
        stats_df = stats_df[~stats_df['Modality'].isin(['GENE_SCORES - hallmark', 'GENE_SCORES - all'])]
    # Filtering dataframe to desired models
    stats_df = stats_df[np.isin(stats_df['Model'], model)]
    # Filtering scMKL runs to best runs; other models are kept as-is
    if 'scMKL' in model:
        stats_df = stats_df[(stats_df['Alpha Star'] == 'Yes') | (stats_df['Model'] != 'scMKL')]
    # Copy so the new column doesn't write into a view of the caller's frame
    stats_df = stats_df[np.isin(stats_df['Modality'], subset_modalities)].copy()
    stats_df['Model (Modality)'] = stats_df['Model'] + " (" + stats_df['Modality'] + ")"

    # Order of lowest to highest mean metric by model and modality. Selecting
    # the metric column keeps the string key column out of the mean.
    group_order = (stats_df.groupby('Model (Modality)')[metric]
                   .mean()
                   .sort_values()
                   .index
                   .tolist())
    stats_df['Model (Modality)'] = pd.Categorical(stats_df['Model (Modality)'], categories = group_order)

    models_bp = px.box(
        data_frame = stats_df,
        x = 'Model (Modality)',
        y = metric,
        color = 'Model',
        template = 'plotly_white',
        height = 700,
        category_orders = {'Model' : ['scMKL', 'XGBoost', 'MLP'],
                           'Model (Modality)' : group_order},
        color_discrete_map = {
            'scMKL' : 'red',
            'XGBoost' : 'blue',
            'MLP' : 'green'
        }
    ).update_traces(width = 0.75
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel = dict(
            bgcolor = "white",
            font_size = 16,
            namelength = 40),
        font = dict(size = 20)
    )
    return models_bp
def plot_umap(umap_dict, modality, dataset, grouping, label, subset):
    '''
    Returns a 3-D plotly scatter of precomputed UMAP embeddings coloured by a cell label.

    umap_dict: nested dict read as
        umap_dict[modality][dataset][subset_key]['Embeddings'] and
        umap_dict[modality][dataset][subset_key]['Cell Labels'][label];
        the embeddings need at least three columns (they become UMAP_1..UMAP_3)
    modality, dataset: the first two keys into umap_dict
    grouping: feature grouping; 'Hallmark' and 'JASPAR' use their own key formats
    label: cell label used to colour the points
    subset: "None" for the most variable features, otherwise a feature group name
    '''
    # Third-level key naming the feature subset the embedding was built from
    if subset == "None":
        subset_key = "Most Variable Features"
    elif grouping == 'Hallmark':
        subset_key = grouping.lower() + '_HALLMARK_' + subset.replace(" ", "_").upper()
    elif grouping == 'JASPAR':
        subset_key = 'motifs_' + subset
    else:
        subset_key = grouping.lower() + "_" + subset.replace(" ", "_")

    embedding_entry = umap_dict[modality][dataset][subset_key]
    axis_names = {0 : "UMAP_1", 1 : "UMAP_2", 2 : "UMAP_3"}
    umap_df = pd.DataFrame(embedding_entry['Embeddings']).rename(columns = axis_names)
    umap_df[label] = np.array(embedding_entry["Cell Labels"][label])

    figure = px.scatter_3d(
        data_frame = umap_df,
        x = 'UMAP_1',
        y = 'UMAP_2',
        z = 'UMAP_3',
        color = label,
        template = 'plotly_white',
        height = 650,
    )
    figure.update_layout(hoverlabel = dict(font_size = 16, namelength = 40))
    figure.update_traces(marker = dict(size = 3))
    return figure
def weights_boxplot(norm_df: pd.DataFrame, dataset, modality, shown_groups = 9):
    '''
    Box plots of normalized group weights by alpha, one facet per group.

    norm_df: a DataFrame with at least the columns 'Dataset', 'Modality',
        'Group', 'Alpha', 'Norm' and 'Proportion Selected'
    dataset: dataset to plot
    modality: modality to plot (formatted with format_datatype_grouping)
    shown_groups: either an int or a list-like of group names
        - if an int n, the n groups with the largest summed
          'Proportion Selected' are shown, most selected first
        - otherwise the listed groups are shown, in the given order
    returns a plotly Figure
    '''
    modality = format_datatype_grouping(modality)
    norm_df = norm_df[(norm_df['Dataset'] == dataset) & (norm_df['Modality'] == modality)]

    if isinstance(shown_groups, (int, np.integer)):
        # Total selection proportion per group, ascending, so the tail holds the top n
        selection_totals = (norm_df.groupby('Group', observed = False)['Proportion Selected']
                            .sum()
                            .sort_values())
        top_groups = np.array(selection_totals.index)[-shown_groups:]
        norm_df = norm_df[norm_df['Group'].isin(top_groups)]
        group_order = list(top_groups[::-1])
    else:
        # Both branches must define the facet order used by category_orders
        group_order = list(shown_groups)
        norm_df = norm_df[np.isin(norm_df['Group'], group_order)]

    # Copy so the Alpha conversion doesn't write into a view of the caller's frame
    norm_df = norm_df.copy()
    # Order alpha levels numerically (largest first) before stringifying, then
    # stringify them exactly like the column so the category labels match
    alpha_order = pd.Series(np.unique(norm_df['Alpha'])[::-1]).astype(str).tolist()
    norm_df['Alpha'] = norm_df['Alpha'].astype(str)

    norm_plot = px.box(
        data_frame = norm_df,
        x = 'Alpha',
        y = 'Norm',
        color = 'Group',
        template = 'plotly_white',
        height = 700,
        facet_col = 'Group',
        facet_col_wrap = 3,
        category_orders = {'Group' : group_order,
                           'Alpha' : alpha_order}
    ).update_traces(
        width = 0.75,
    ).update_yaxes(title = ''
    ).update_xaxes(title = ''
    ).update_layout(
        hovermode = 'x unified',
        hoverlabel = dict(
            bgcolor = "white",
            font_size = 16,
            namelength = 40),
        font = dict(size = 20),
        showlegend = False,
        yaxis4 = dict(title = "Normalized Weight"),
        xaxis2 = dict(title = "Alpha")
    # Facet titles: keep the text after the last "=" and show underscores as spaces
    ).for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1].replace("_", " ")))
    return norm_plot
def plot_features(selections_df, dataset, modality, n_features = 40):
    '''
    Horizontal bar plot of the features scMKL selected most often for one experiment.

    selections_df: a DataFrame with at least the columns 'Dataset', 'Modality',
        'feature', 'selection' and 'Number of Groups Feature in'
    dataset: dataset to plot
    modality: modality to plot (formatted with format_datatype_grouping)
    n_features: how many of the most frequently selected features to show
        (all of them when fewer exist)
    NOTE: if 'motif' appears in the formatted modality, returns None
    Returns a plotly bar Figure of the top selected features, or None.
    '''
    modality = format_datatype_grouping(modality)
    if 'motif' in modality:
        return None
    # Formatting DataFrame
    selections_df = selections_df[(selections_df['Dataset'] == dataset) & (selections_df['Modality'] == modality)]
    # Ascending sort, so the last rows are the most frequently selected features
    selections_df = selections_df.sort_values(by = 'selection', ascending = True)
    selections_df['feature'] = pd.Categorical(selections_df['feature'], categories = selections_df['feature'])
    # tail() keeps every row when there are fewer than n_features; a
    # negative iloc start would instead wrap around and drop rows
    selections_df = selections_df.tail(n_features)

    gf_bar = px.bar(
        data_frame = selections_df,
        orientation = 'h',
        x = 'selection',
        y = 'feature',
        template = 'plotly_white',
        color = 'Number of Groups Feature in',
        height = 700,
        color_continuous_scale = px.colors.sequential.Bluered,
    ).update_layout(
        xaxis = dict(title = 'Times Selected by scMKL'),
        yaxis = dict(title = 'Features'),
        font = dict(size = 12),
    )
    return gf_bar
def create_volcano(vol_df, dataset, modality, grouping, group, grouping_dict):
    '''
    Volcano plot of -log10(adjusted p-value) against log fold change, coloured by enrichment.

    vol_df: a DataFrame with 'Dataset', 'Enrichment' and '-log10(adjusted p-val)'
        columns, plus modality-specific columns:
            RNA   -> 'logfoldchanges', 'names', 'pvals_adj'
            other -> 'log2(fold_change)', 'feature name', 'adjusted p-value'
    dataset: dataset name; 'song_prostate' and 'prostate' are mapped to
        'prostate_rna' and 'prostate_atac'
    modality: "RNA" selects the RNA columns; any other value uses the ATAC columns
    grouping: feature grouping key into grouping_dict (e.g. 'Hallmark')
    group: feature group to restrict the plot to, or "None" for all features
    grouping_dict: nested dict grouping_dict[dataset]['RNA' | 'ATAC'][grouping][group]
        giving the features of a group
    Returns a plotly scatter Figure.
    '''
    dataset_aliases = {"song_prostate" : 'prostate_rna', 'prostate' : 'prostate_atac'}
    dataset = dataset_aliases.get(dataset, dataset)
    reg_colors = {'Up-regulated' : 'green',
                  'Down-regulated' : 'red',
                  'Not significant' : 'blue'}
    # Copy so relabelling 'Enrichment' doesn't write into a view of the caller's frame
    vol_df = vol_df[vol_df['Dataset'] == dataset].copy()

    if modality == "RNA":
        lfc = "logfoldchanges"
        label_name = 'names'
        adj_pval = 'pvals_adj'
        if group != "None":
            group = "HALLMARK_" + group.replace(" ", "_").upper()
            vol_df = vol_df[np.isin(vol_df['names'], list(grouping_dict[dataset]['RNA'][grouping][group]))]
    else:
        lfc = "log2(fold_change)"
        label_name = 'feature name'
        adj_pval = 'adjusted p-value'
        # Collapse any label containing 'Up' / 'Down' to the two legend labels
        vol_df['Enrichment'] = vol_df['Enrichment'].apply(
            lambda x: 'Up-regulated' if 'Up' in x else ('Down-regulated' if 'Down' in x else x))
        if group != "None":
            if grouping == "Hallmark":
                group = "HALLMARK_" + group.upper().replace(" ", "_")
            vol_df = vol_df[np.isin(vol_df['feature name'], list(grouping_dict[dataset]['ATAC'][grouping][group]))]

    vol_plot = px.scatter(
        data_frame = vol_df,
        x = lfc,
        y = '-log10(adjusted p-val)',
        color = 'Enrichment',
        template = 'plotly_white',
        hover_name = label_name,
        hover_data = adj_pval,
        color_discrete_map = reg_colors,
        height = 650,
    ).update_layout(
        hoverlabel = dict(
            font_size = 16,
            namelength = 40),
        font = dict(size = 20)
    )
    return vol_plot
def gene_distribution(freq_df):
    '''
    Histogram of how many gene sets each gene appears in, with a log-scaled count axis.

    freq_df: a DataFrame of genes with a 'Number of Sets' column
    Returns a plotly histogram Figure.
    '''
    histogram_fig = px.histogram(
        freq_df,
        x = 'Number of Sets',
        log_y = True,
        color_discrete_sequence = ['blue'],
        template = 'plotly_white',
        title = "Distribution of Hallmark Gene Overlap",
    )
    histogram_fig.update_layout(font = dict(size = 16), yaxis = dict(title = "log(Counts)"))
    return histogram_fig
def GO_plot(GO_df, dataset):
    '''
    Horizontal bar plot of the 30 most enriched GO biological-process gene sets for a dataset.

    GO_df: a DataFrame with columns 'Dataset', 'Group Name' and
        'GSE (-log10(adj. p-val))'
    dataset: dataset to plot
    Returns a plotly bar Figure, bars ordered from most to least enriched.
    '''
    score_col = 'GSE (-log10(adj. p-val))'
    top_sets = (GO_df.loc[GO_df['Dataset'] == dataset]
                .sort_values(by = score_col, ascending = False)
                .head(30)
                .reset_index()
                .rename(columns = {"Group Name" : "Gene Sets"}))
    # Keep only the part of each set name before the first " ("
    top_sets['Gene Sets'] = [name.split(" (")[0] for name in top_sets['Gene Sets']]

    GO_fig = px.bar(
        top_sets,
        x = score_col,
        y = 'Gene Sets',
        color_discrete_sequence = ['pink'],
        template = 'plotly_white',
        category_orders = {'Gene Sets' : top_sets['Gene Sets']},
        height = 700,
    )
    GO_fig.update_layout(yaxis = dict(dtick = 1), font = dict(size = 16))
    return GO_fig
def hallmark_genesets_plot(hallmark_df, dataset):
    '''
    Three side-by-side horizontal bar panels comparing Hallmark gene sets for a
    dataset: proportion of DE features, gene set enrichment and scMKL
    selection frequency.

    hallmark_df: a long-format DataFrame with columns 'Dataset', 'Group',
        'Variable' and 'Value'; 'Variable' holds the three panel names used
        in category_orders below
    dataset: dataset to plot
    Returns a plotly bar Figure with gene sets ordered by their
    'Proportion of DE Features' value, highest first.
    '''
    hallmark_df = hallmark_df[hallmark_df['Dataset'] == dataset]
    de_rows = hallmark_df[hallmark_df['Variable'] == 'Proportion of DE Features']
    order = de_rows.sort_values(by = 'Value', ascending = False)['Group'].tolist()

    hallmark_plot = px.bar(
        data_frame = hallmark_df,
        orientation = 'h',
        x = 'Value',
        y = 'Group',
        facet_col = 'Variable',
        color = 'Variable',
        template = 'plotly_white',
        height = 900,
        category_orders = {'Variable' : ['Proportion of DE Features', 'Gene Set Enrichment (-log10(adjusted p-value))', 'scMKL Selection Frequency'],
                           'Group' : order},
        hover_name = 'Group',
        color_discrete_sequence = ['blue', "orange", "red"]
    ).update_layout(
        yaxis = dict(title = 'Gene Sets', dtick = 1),
        font = dict(size = 16),
        xaxis1 = dict(title = "Proportion of DEG Overlap with Hallmark Gene Sets"),
        xaxis2 = dict(title = "-log10(adjusted p-value)"),
        xaxis3 = dict(title = "Times selected by scMKL"),
        showlegend = False
    # Unlink the facet x-axes so each panel can use its own range
    ).update_xaxes(matches = None
    # Facet titles: keep the text after the last "=" and show underscores as spaces
    ).for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1].replace("_", " ")))
    return hallmark_plot