#%%
"""Gradio demo: cell-type gene expression, motif networks and AF2 TF pairs.

The script is organised as ``#%%`` cells so it can be run either as a plain
script or interactively, cell by cell.
"""
import argparse
import os
import sys
from glob import glob

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pkg_resources
from atac_rna_data_processing.config.load_config import load_config
from atac_rna_data_processing.io.celltype import GETCellType
from atac_rna_data_processing.io.nr_motif_v1 import NrMotifV1
from dash_bio import Clustergram
from proscope.af2 import AFPairseg
from proscope.data import get_genename_to_uniprot, get_lddt, get_seq
from proscope.protein import Protein
from proscope.viewer import view_pdb_html

# Shared lookup data loaded once at import time.
seq = get_seq()
genename_to_uniprot = get_genename_to_uniprot()
lddt = get_lddt()

#%%
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--port", type=int, default=7860, help="Port number")
parser.add_argument("-s", "--share", action="store_true", help="Share on network")
parser.add_argument("-d", "--data", type=str, default="/data", help="Data directory")
# NOTE: pseudo arguments are hard-coded so the cells also run interactively.
# To honour the real command line, call ``parser.parse_args()`` instead.
args = parser.parse_args(['-p', '7869', '-s', '-d', '/manitou/pmg/users/xf2217/demo_data'])

#%%
# Every entry under structures/causal is offered as a selectable TF pair.
gene_pairs = [
    os.path.basename(pair) for pair in glob(f"{args.data}/structures/causal/*")
]

GET_CONFIG = load_config(
    '/manitou/pmg/users/xf2217/atac_rna_data_processing/atac_rna_data_processing/config/GET'
)
GET_CONFIG.celltype.jacob = True
GET_CONFIG.celltype.num_cls = 2
GET_CONFIG.celltype.input = True
GET_CONFIG.celltype.embed = True
GET_CONFIG.celltype.data_dir = (
    '/manitou/pmg/users/xf2217/pretrain_human_bingren_shendure_apr2023/fetal_adult/'
)
GET_CONFIG.celltype.interpret_dir = (
    '/manitou/pmg/users/xf2217/Interpretation_all_hg38_allembed_v4_natac/'
)
GET_CONFIG.motif_dir = '/manitou/pmg/users/xf2217/interpret_natac/motif-clustering'

motif = NrMotifV1.load_from_pickle(
    pkg_resources.resource_filename("atac_rna_data_processing", "data/NrMotifV1.pkl"),
    GET_CONFIG.motif_dir,
)

# The annotation table lives next to (not inside) the 'fetal_adult' directory.
cell_type_annot = pd.read_csv(
    GET_CONFIG.celltype.data_dir.split('fetal_adult')[0]
    + 'data/cell_type_pretrain_human_bingren_shendure_apr2023.txt'
)
cell_type_id_to_name = dict(zip(cell_type_annot['id'], cell_type_annot['celltype']))
cell_type_name_to_id = dict(zip(cell_type_annot['celltype'], cell_type_annot['id']))
# Each entry in interpret_dir is named by cell-type id; map ids to display names.
avaliable_celltypes = sorted(
    cell_type_id_to_name[os.path.basename(f)]
    for f in glob(GET_CONFIG.celltype.interpret_dir + '*')
)

#%%
# Callbacks wired to the Gradio components below.
# set plot dpi to 100
plt.rcParams['figure.dpi'] = 100


def visualize_AF2(tf_pair, a):
    """Build the AFPairseg for ``tf_pair`` and return its plots.

    Args:
        tf_pair: Name of a directory under ``structures/causal`` and
            ``sequences/causal`` inside ``args.data``.
        a: Previous state value; ignored and replaced by a new AFPairseg.

    Returns:
        Five figures (the two pair pLDDT plots, the two per-protein pLDDT
        plots and the score heatmap), a dropdown update listing the keys of
        ``pairs_data``, and the new AFPairseg to store in state.

    Raises:
        gr.Error: If the structure directory for ``tf_pair`` does not exist.
    """
    structure_dir = f"{args.data}/structures/causal/{tf_pair}"
    fasta_dir = f"{args.data}/sequences/causal/{tf_pair}"
    if not os.path.exists(structure_dir):
        # gr.Error is shown to the user and aborts the callback.
        raise gr.Error("No such gene pair")

    a = AFPairseg(structure_dir, fasta_dir)
    segment_names = list(a.pairs_data.keys())
    # `segpair` is the module-level Dropdown created in the __main__ block;
    # keep its server-side choices in step with what the browser receives.
    segpair.choices = segment_names

    fig1, _ = a.plot_plddt_gene1()
    fig2, _ = a.plot_plddt_gene2()
    fig3, _ = a.protein1.plot_plddt()
    fig4, _ = a.protein2.plot_plddt()
    fig5, _ = a.plot_score_heatmap()
    plt.tight_layout()  # acts on pyplot's current figure

    new_dropdown = update_dropdown(segment_names, 'Segment pair')
    return fig1, fig2, fig3, fig4, fig5, new_dropdown, a


def view_pdb(seg_pair, a):
    """Return the HTML viewer, the state object and the PDB path for a segment pair.

    Args:
        seg_pair: Key into ``a.pairs_data``.
        a: State object exposing ``pairs_data`` (set by ``visualize_AF2``).
    """
    pdb_path = a.pairs_data[seg_pair].pdb
    return view_pdb_html(pdb_path), a, pdb_path


def update_dropdown(x, label):
    """Return a Gradio update that sets a dropdown's choices to ``x`` and its label."""
    return gr.Dropdown.update(choices=x, label=label)


def load_and_plot_celltype(celltype_name, GET_CONFIG, cell):
    """Load a cell type and return its expression figure, table and object.

    Args:
        celltype_name: Display name, a key of ``cell_type_name_to_id``.
        GET_CONFIG: Configuration passed to ``GETCellType``.
        cell: Previous state value; ignored and replaced by the loaded cell.

    Returns:
        The gene-expression figure, a table of per-gene mean ``pred``,
        ``obs`` and ``accessibility``, and the loaded ``GETCellType``.
    """
    celltype_id = cell_type_name_to_id[celltype_name]
    cell = GETCellType(celltype_id, GET_CONFIG)
    cell.celltype_name = celltype_name
    gene_exp_fig = cell.plotly_gene_exp()
    gene_exp_table = (
        cell.gene_annot
        .groupby('gene_name')[['pred', 'obs', 'accessibility']]
        .mean()
        .reset_index()
    )
    return gene_exp_fig, gene_exp_table, cell


def plot_gene_regions(cell, gene_name, plotly=True):
    """Return ``cell.plot_gene_regions(gene_name)`` together with ``cell``."""
    return cell.plot_gene_regions(gene_name, plotly=plotly), cell


def plot_gene_motifs(cell, gene_name, motif, overwrite=False):
    """Return the first element of ``cell.plot_gene_motifs(...)`` together with ``cell``."""
    return cell.plot_gene_motifs(gene_name, motif, overwrite=overwrite)[0], cell


def plot_motif_subnet(cell, motif_collection, m, type='neighbors', threshold=0.1):
    """Return the subnetwork plot for motif ``m`` together with ``cell``.

    ``type`` and ``threshold`` are forwarded unchanged to
    ``cell.plotly_motif_subnet``.
    """
    return (
        cell.plotly_motif_subnet(motif_collection, m, type=type, threshold=threshold),
        cell,
    )


def plot_gene_exp(cell, plotly=True):
    """Return ``cell.plotly_gene_exp(plotly=plotly)`` together with ``cell``."""
    return cell.plotly_gene_exp(plotly=plotly), cell


def plot_motif_corr(cell):
    """Return a Clustergram of ``cell.gene_by_motif.corr`` together with ``cell``."""
    corr = cell.gene_by_motif.corr
    fig = Clustergram(
        data=corr,
        column_labels=list(corr.columns.values),
        row_labels=list(corr.index),
        hidden_labels=['row', 'col'],
        link_method='average',
        display_ratio=0.1,
        width=600,
        height=400,
        color_map='rdbu_r',
    )
    return fig, cell


#%%
# main
if __name__ == '__main__':
    with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
        seg_pairs = gr.State([''])
        af = gr.State(None)
        cell = gr.State(None)

        with gr.Row():
            # Left column: Plot gene expression and gene regions
            with gr.Column():
                with gr.Row():
                    celltype_name = gr.Dropdown(label='Cell Type', choices=avaliable_celltypes)
                    celltype_btn = gr.Button(value='Load & Plot Gene Expression')
                gene_exp_plot = gr.Plot(label='Gene Expression Pred vs Obs')
                gene_exp_table = gr.DataFrame(label='Gene Expression Table', max_rows=10)

            # Right column: Plot gene motifs
            with gr.Column():
                gene_name_for_region = gr.Textbox(
                    label='Get important regions or motifs for gene:'
                )
                with gr.Row():
                    region_plot_btn = gr.Button(value='Regions')
                    motif_plot_btn = gr.Button(value='Motifs')
                region_plot = gr.Plot(label='Gene Regions')
                motif_plot = gr.Plot(label='Gene Motifs')

        with gr.Row():
            # Left column: Motif correlation heatmap
            with gr.Column():
                clustergram_btn = gr.Button(value='Plot Motif Correlation Heatmap')
                clustergram_plot = gr.Plot(label='Motif Correlation')

            # Right column: Motif subnet plot
            with gr.Column():
                with gr.Row():
                    motif_for_subnet = gr.Dropdown(
                        label='Motif Causal Subnetwork', choices=motif.cluster_names
                    )
                    # `value` (not the pre-3.0 `default`) sets the initial selection.
                    subnet_type = gr.Dropdown(
                        label='Type',
                        choices=['neighbors', 'parents', 'children'],
                        value='neighbors',
                    )
                    # slider for threshold 0.01-0.25
                    subnet_threshold = gr.Slider(
                        label='Threshold', minimum=0.01, maximum=0.25, step=0.01, value=0.1
                    )
                subnet_btn = gr.Button(value='Plot Motif Causal Subnetwork')
                subnet_plot = gr.Plot(label='Motif Causal Subnetwork')

        with gr.Row():
            # Left column: pLDDT plots and heatmap for the chosen TF pair
            with gr.Column():
                with gr.Row():
                    tf_pairs = gr.Dropdown(label='TF pair', choices=gene_pairs)
                    tf_pairs_btn = gr.Button(value='Load & Plot')
                interact_plddt1 = gr.Plot(label='Interact pLDDT 1')
                interact_plddt2 = gr.Plot(label='Interact pLDDT 2')
                protein1_plddt = gr.Plot(label='Protein 1 pLDDT')
                protein2_plddt = gr.Plot(label='Protein 2 pLDDT')
                heatmap = gr.Plot(label='Heatmap')

            # Right column: PDB viewer and download for one segment pair
            with gr.Column():
                with gr.Row():
                    segpair = gr.Dropdown(label='Seg pair', choices=seg_pairs.value)
                    segpair_btn = gr.Button(value='Get PDB')
                pdb_html = gr.HTML(label="PDB HTML")
                pdb_file = gr.File(label='Download PDB')

        tf_pairs_btn.click(
            visualize_AF2,
            inputs=[tf_pairs, af],
            outputs=[
                interact_plddt1,
                interact_plddt2,
                protein1_plddt,
                protein2_plddt,
                heatmap,
                segpair,
                af,
            ],
        )
        segpair_btn.click(view_pdb, inputs=[segpair, af], outputs=[pdb_html, af, pdb_file])
        celltype_btn.click(
            load_and_plot_celltype,
            inputs=[celltype_name, gr.State(GET_CONFIG), cell],
            outputs=[gene_exp_plot, gene_exp_table, cell],
        )
        region_plot_btn.click(
            plot_gene_regions,
            inputs=[cell, gene_name_for_region],
            outputs=[region_plot, cell],
        )
        motif_plot_btn.click(
            plot_gene_motifs,
            inputs=[cell, gene_name_for_region, gr.State(motif)],
            outputs=[motif_plot, cell],
        )
        clustergram_btn.click(
            plot_motif_corr, inputs=[cell], outputs=[clustergram_plot, cell]
        )
        subnet_btn.click(
            plot_motif_subnet,
            inputs=[cell, gr.State(motif), motif_for_subnet, subnet_type, subnet_threshold],
            outputs=[subnet_plot, cell],
        )

    demo.launch(share=args.share, server_port=args.port)
# %%