Spaces:
Running
Running
#%% | |
import argparse | |
import os | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import pkg_resources | |
from proscope.data import get_genename_to_uniprot, get_lddt, get_seq | |
import pandas as pd | |
from dash_bio import Clustergram | |
seq = get_seq() | |
genename_to_uniprot = get_genename_to_uniprot() | |
lddt = get_lddt() | |
import sys | |
from glob import glob | |
import numpy as np | |
from atac_rna_data_processing.config.load_config import load_config | |
from atac_rna_data_processing.io.celltype import GETCellType | |
from atac_rna_data_processing.io.nr_motif_v1 import NrMotifV1 | |
from proscope.af2 import AFPairseg | |
from proscope.protein import Protein | |
from proscope.viewer import view_pdb_html | |
#%% | |
# Command-line options for the demo server.
_cli = argparse.ArgumentParser()
_cli.add_argument("-p", "--port", type=int, default=7860, help="Port number")
_cli.add_argument("-s", "--share", action="store_true", help="Share on network")
_cli.add_argument("-d", "--data", type=str, default="/data", help="Data directory")
# The argument list is pinned here, so the real command line is not consulted.
# Call `_cli.parse_args()` with no argument to honour sys.argv instead.
args = _cli.parse_args(['-p', '7869', '-s', '-d', '/manitou/pmg/users/xf2217/demo_data'])
#%% | |
# Names of the entries found under <data>/structures/causal (directory part
# stripped); these populate the "TF pair" dropdown.
gene_pairs = [
    os.path.basename(pair_path)
    for pair_path in glob(f"{args.data}/structures/causal/*")
]
# Load the base GET configuration, then override the cell-type options and
# the data locations used by this demo.
GET_CONFIG = load_config('/manitou/pmg/users/xf2217/atac_rna_data_processing/atac_rna_data_processing/config/GET')
GET_CONFIG.celltype.jacob = True
GET_CONFIG.celltype.num_cls = 2
GET_CONFIG.celltype.input = True
GET_CONFIG.celltype.embed = True
GET_CONFIG.celltype.data_dir = '/manitou/pmg/users/xf2217/pretrain_human_bingren_shendure_apr2023/fetal_adult/'
GET_CONFIG.celltype.interpret_dir = '/manitou/pmg/users/xf2217/Interpretation_all_hg38_allembed_v4_natac/'
GET_CONFIG.motif_dir = '/manitou/pmg/users/xf2217/interpret_natac/motif-clustering'

# Motif collection: the pickle shipped inside the atac_rna_data_processing
# package, loaded together with the configured motif directory.
_motif_pickle = pkg_resources.resource_filename("atac_rna_data_processing", "data/NrMotifV1.pkl")
motif = NrMotifV1.load_from_pickle(_motif_pickle, GET_CONFIG.motif_dir)

# The annotation table is addressed relative to the part of data_dir that
# precedes 'fetal_adult'.
_annot_root = GET_CONFIG.celltype.data_dir.split('fetal_adult')[0]
cell_type_annot = pd.read_csv(_annot_root + 'data/cell_type_pretrain_human_bingren_shendure_apr2023.txt')

# Lookup tables in both directions between the 'id' and 'celltype' columns.
cell_type_id_to_name = dict(zip(cell_type_annot['id'], cell_type_annot['celltype']))
cell_type_name_to_id = {
    name: type_id
    for name, type_id in zip(cell_type_annot['celltype'], cell_type_annot['id'])
}

# Every entry in interpret_dir is looked up by its basename as a cell-type id;
# the matching names, sorted, feed the "Cell Type" dropdown.
avaliable_celltypes = sorted(
    cell_type_id_to_name[os.path.basename(entry)]
    for entry in glob(GET_CONFIG.celltype.interpret_dir + '*')
)
#%% | |
# fill this in... | |
# Render matplotlib figures at 100 DPI by default.
plt.rcParams.update({'figure.dpi': 100})
def visualize_AF2(tf_pair, a):
    """Load the AlphaFold2 results for one TF pair and build its overview plots.

    Args:
        tf_pair: Name of the pair's directory under
            ``{args.data}/structures/causal`` (and ``sequences/causal``).
        a: Previous value of the ``af`` session state. It is ignored and
            replaced by a freshly built ``AFPairseg``.

    Returns:
        A 7-tuple: the figures from ``plot_plddt_gene1``, ``plot_plddt_gene2``,
        ``protein1.plot_plddt``, ``protein2.plot_plddt`` and
        ``plot_score_heatmap``; a dropdown update listing the keys of
        ``pairs_data``; and the new ``AFPairseg`` instance.

    Raises:
        gr.Error: If no structure directory exists for ``tf_pair``.
    """
    structure_dir = f"{args.data}/structures/causal/{tf_pair}"
    fasta_dir = f"{args.data}/sequences/causal/{tf_pair}"
    if not os.path.exists(structure_dir):
        # Raise (rather than merely construct an error object) so Gradio shows
        # the message and we stop before AFPairseg touches a missing directory.
        raise gr.Error("No such gene pair")

    a = AFPairseg(structure_dir, fasta_dir)
    segment_pairs = list(a.pairs_data.keys())
    # Mirror the choices onto the module-level dropdown component as well.
    # NOTE(review): `segpair` is only defined when the file runs as a script,
    # and the returned dropdown update below already carries the same choices.
    segpair.choices = list(segment_pairs)

    fig1, _ = a.plot_plddt_gene1()
    fig2, _ = a.plot_plddt_gene2()
    fig3, _ = a.protein1.plot_plddt()
    fig4, _ = a.protein2.plot_plddt()
    fig5, _ = a.plot_score_heatmap()
    # tight_layout acts on pyplot's current figure only.
    plt.tight_layout()

    new_dropdown = update_dropdown(segment_pairs, 'Segment pair')
    return fig1, fig2, fig3, fig4, fig5, new_dropdown, a
def view_pdb(seg_pair, a):
    """Return the viewer HTML and file path for one segment pair's PDB.

    Args:
        seg_pair: Key into ``a.pairs_data`` selecting the segment pair.
        a: The loaded ``AFPairseg`` session state, or ``None`` if no TF pair
            has been loaded yet.

    Returns:
        A 3-tuple ``(html, a, pdb_path)``: the output of ``view_pdb_html`` for
        the PDB, the unchanged state object, and the PDB path.

    Raises:
        gr.Error: If no TF pair has been loaded, or ``seg_pair`` is not a key
            of ``a.pairs_data``.
    """
    if a is None:
        # The session state starts as None until "Load & Plot" has run.
        raise gr.Error("Load a TF pair before requesting a PDB")
    try:
        pdb_path = a.pairs_data[seg_pair].pdb
    except KeyError as err:
        raise gr.Error("No such segment pair") from err
    return view_pdb_html(pdb_path), a, pdb_path
def update_dropdown(x, label):
    """Build a Gradio dropdown update that sets new choices and a new label.

    Args:
        x: The choices to show.
        label: The label to show.

    Returns:
        The value of ``gr.Dropdown.update(choices=x, label=label)``.
    """
    changes = {"choices": x, "label": label}
    return gr.Dropdown.update(**changes)
def load_and_plot_celltype(celltype_name, GET_CONFIG, cell):
    """Load a cell type and return its expression plot and per-gene table.

    Args:
        celltype_name: Cell-type name; translated to an id through
            ``cell_type_name_to_id``.
        GET_CONFIG: Configuration object handed to ``GETCellType``.
        cell: Previous session state; ignored and replaced.

    Returns:
        A 3-tuple: the figure from ``cell.plotly_gene_exp()``, a table with
        the per-``gene_name`` mean of ``pred``, ``obs`` and ``accessibility``,
        and the newly created ``GETCellType``.
    """
    cell = GETCellType(cell_type_name_to_id[celltype_name], GET_CONFIG)
    cell.celltype_name = celltype_name

    figure = cell.plotly_gene_exp()

    value_columns = ['pred', 'obs', 'accessibility']
    per_gene_mean = cell.gene_annot.groupby('gene_name')[value_columns].mean()
    table = per_gene_mean.reset_index()
    return figure, table, cell
def plot_gene_regions(cell, gene_name, plotly=True):
    """Plot the regions for ``gene_name`` and pass the cell state through.

    Returns:
        ``(figure, cell)`` where ``figure`` is whatever
        ``cell.plot_gene_regions(gene_name, plotly=plotly)`` returns.
    """
    region_figure = cell.plot_gene_regions(gene_name, plotly=plotly)
    return region_figure, cell
def plot_gene_motifs(cell, gene_name, motif, overwrite=False):
    """Plot the motifs for ``gene_name`` and pass the cell state through.

    Returns:
        ``(figure, cell)`` where ``figure`` is the first element of the result
        of ``cell.plot_gene_motifs(gene_name, motif, overwrite=overwrite)``.
    """
    motif_result = cell.plot_gene_motifs(gene_name, motif, overwrite=overwrite)
    motif_figure = motif_result[0]
    return motif_figure, cell
def plot_motif_subnet(cell, motif_collection, m, type='neighbors', threshold=0.1):
    """Plot the subnetwork around motif ``m`` and pass the cell state through.

    Returns:
        ``(figure, cell)`` where ``figure`` is whatever
        ``cell.plotly_motif_subnet(motif_collection, m, type=..., threshold=...)``
        returns.
    """
    subnet_options = {"type": type, "threshold": threshold}
    subnet_figure = cell.plotly_motif_subnet(motif_collection, m, **subnet_options)
    return subnet_figure, cell
def plot_gene_exp(cell, plotly=True):
    """Plot gene expression for ``cell`` and pass the cell state through.

    Returns:
        ``(figure, cell)`` where ``figure`` is whatever
        ``cell.plotly_gene_exp(plotly=plotly)`` returns.
    """
    expression_figure = cell.plotly_gene_exp(plotly=plotly)
    return expression_figure, cell
def plot_motif_corr(cell):
    """Draw a clustered heatmap of ``cell.gene_by_motif.corr``.

    ``corr`` is read as an attribute (not called) and must expose ``columns``
    and ``index``; presumably a square correlation table — confirm against
    the ``gene_by_motif`` implementation.

    Returns:
        ``(figure, cell)`` where ``figure`` is the ``Clustergram`` result.
    """
    corr = cell.gene_by_motif.corr
    clustergram_options = {
        "data": corr,
        "column_labels": list(corr.columns.values),
        "row_labels": list(corr.index),
        # Hide the tick labels on both axes.
        "hidden_labels": ['row', 'col'],
        "link_method": 'average',
        "display_ratio": 0.1,
        "width": 600,
        "height": 400,
        "color_map": 'rdbu_r',
    }
    figure = Clustergram(**clustergram_options)
    return figure, cell
#%% | |
# fill this in... | |
# main | |
if __name__ == '__main__':
    # Build the Gradio UI and start the server.
    # NOTE(review): the nesting of rows/columns below was reconstructed from
    # the component order and the "Left column"/"Right column" comments;
    # confirm against the intended layout.
    with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
        # Per-session state: segment-pair choices, the loaded AFPairseg, and
        # the loaded GETCellType.
        seg_pairs = gr.State([''])
        af = gr.State(None)
        cell = gr.State(None)

        # --- Row 1: gene expression | regions and motifs for one gene ---
        with gr.Row() as row:
            # Left column: Plot gene expression and gene regions
            with gr.Column():
                with gr.Row() as row:
                    celltype_name = gr.Dropdown(label='Cell Type', choices=avaliable_celltypes)
                    celltype_btn = gr.Button(value='Load & Plot Gene Expression')
                gene_exp_plot = gr.Plot(label='Gene Expression Pred vs Obs')
                gene_exp_table = gr.DataFrame(label='Gene Expression Table', max_rows=10)
            # Right column: Plot gene motifs
            with gr.Column():
                gene_name_for_region = gr.Textbox(label='Get important regions or motifs for gene:')
                with gr.Row() as row:
                    region_plot_btn = gr.Button(value='Regions')
                    motif_plot_btn = gr.Button(value='Motifs')
                region_plot = gr.Plot(label='Gene Regions')
                motif_plot = gr.Plot(label='Gene Motifs')

        # --- Row 2: motif correlation heatmap | motif causal subnetwork ---
        with gr.Row() as row:
            with gr.Column():
                clustergram_btn = gr.Button(value='Plot Motif Correlation Heatmap')
                clustergram_plot = gr.Plot(label='Motif Correlation')
            # Right column: Motif subnet plot
            with gr.Column():
                with gr.Row() as row:
                    motif_for_subnet = gr.Dropdown(label='Motif Causal Subnetwork', choices=motif.cluster_names)
                    # `value` (not `default`) sets the initial selection, so
                    # the handler receives 'neighbors' unless the user changes it.
                    subnet_type = gr.Dropdown(label='Type', choices=['neighbors', 'parents', 'children'], value='neighbors')
                    # slider for threshold 0.01-0.25
                    subnet_threshold = gr.Slider(label='Threshold', minimum=0.01, maximum=0.25, step=0.01, value=0.1)
                subnet_btn = gr.Button(value='Plot Motif Causal Subnetwork')
                subnet_plot = gr.Plot(label='Motif Causal Subnetwork')

        # --- Row 3: AlphaFold2 pair plots | PDB viewer and download ---
        with gr.Row() as row:
            with gr.Column():
                with gr.Row() as row:
                    tf_pairs = gr.Dropdown(label='TF pair', choices=gene_pairs)
                    tf_pairs_btn = gr.Button(value='Load & Plot')
                interact_plddt1 = gr.Plot(label='Interact pLDDT 1')
                interact_plddt2 = gr.Plot(label='Interact pLDDT 2')
                protein1_plddt = gr.Plot(label='Protein 1 pLDDT')
                protein2_plddt = gr.Plot(label='Protein 2 pLDDT')
                heatmap = gr.Plot(label='Heatmap')
            with gr.Column():
                with gr.Row() as row:
                    segpair = gr.Dropdown(label='Seg pair', choices=seg_pairs.value)
                    segpair_btn = gr.Button(value='Get PDB')
                pdb_html = gr.HTML(label="PDB HTML")
                pdb_file = gr.File(label='Download PDB')

        # Event wiring. Each handler returns the state object it was given (or
        # a replacement) so the session state stays current.
        tf_pairs_btn.click(visualize_AF2, inputs=[tf_pairs, af], outputs=[interact_plddt1, interact_plddt2, protein1_plddt, protein2_plddt, heatmap, segpair, af])
        segpair_btn.click(view_pdb, inputs=[segpair, af], outputs=[pdb_html, af, pdb_file])
        celltype_btn.click(load_and_plot_celltype, inputs=[celltype_name, gr.State(GET_CONFIG), cell], outputs=[gene_exp_plot, gene_exp_table, cell])
        region_plot_btn.click(plot_gene_regions, inputs=[cell, gene_name_for_region], outputs=[region_plot, cell])
        motif_plot_btn.click(plot_gene_motifs, inputs=[cell, gene_name_for_region, gr.State(motif)], outputs=[motif_plot, cell])
        clustergram_btn.click(plot_motif_corr, inputs=[cell], outputs=[clustergram_plot, cell])
        subnet_btn.click(plot_motif_subnet, inputs=[cell, gr.State(motif), motif_for_subnet, subnet_type, subnet_threshold], outputs=[subnet_plot, cell])

    demo.launch(share=args.share, server_port=args.port)
# %% | |