File size: 8,756 Bytes
12761b6
 
8bdf52a
 
 
 
12761b6
 
 
 
 
8bdf52a
 
 
12761b6
 
 
 
 
 
 
8bdf52a
 
 
 
12761b6
8bdf52a
 
 
 
12761b6
 
 
 
 
8bdf52a
12761b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdf52a
 
 
12761b6
 
8bdf52a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12761b6
8bdf52a
 
 
12761b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdf52a
 
 
 
 
 
 
12761b6
 
8bdf52a
12761b6
8bdf52a
12761b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bdf52a
 
 
 
 
 
 
 
12761b6
 
 
8bdf52a
 
 
 
 
12761b6
 
 
 
 
8bdf52a
 
 
12761b6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#%%
import argparse
import os

import gradio as gr
import matplotlib.pyplot as plt
import pkg_resources
from proscope.data import get_genename_to_uniprot, get_lddt, get_seq
import pandas as pd 
from dash_bio import Clustergram

# Load the proscope data objects once, at import time.
# NOTE(review): none of these three names is referenced again in the visible
# part of this file — confirm whether they are still needed before removing.
seq = get_seq()
genename_to_uniprot = get_genename_to_uniprot()
lddt = get_lddt()
import sys
from glob import glob

import numpy as np
from atac_rna_data_processing.config.load_config import load_config
from atac_rna_data_processing.io.celltype import GETCellType
from atac_rna_data_processing.io.nr_motif_v1 import NrMotifV1
from proscope.af2 import AFPairseg
from proscope.protein import Protein
from proscope.viewer import view_pdb_html

#%%
# Command-line options for the demo server.
_parser = argparse.ArgumentParser()
_parser.add_argument("-p", "--port", type=int, default=7860, help="Port number")
_parser.add_argument("-s", "--share", action="store_true", help="Share on network")
_parser.add_argument("-d", "--data", type=str, default="/data", help="Data directory")

# The real command line is deliberately bypassed: a fixed pseudo argument
# list is parsed instead, so the options above always resolve to these values.
# To honour the actual command line, call `_parser.parse_args()` with no list.
_PSEUDO_ARGV = ['-p', '7869', '-s', '-d', '/manitou/pmg/users/xf2217/demo_data']
args = _parser.parse_args(_PSEUDO_ARGV)
#%%
# Gene pairs are the entry names found under <data>/structures/causal.
gene_pairs = [
    os.path.basename(pair)
    for pair in glob(f"{args.data}/structures/causal/*")
]

# Load the GET configuration and override the cell-type settings for the demo.
GET_CONFIG = load_config('/manitou/pmg/users/xf2217/atac_rna_data_processing/atac_rna_data_processing/config/GET')
GET_CONFIG.celltype.jacob = True
GET_CONFIG.celltype.num_cls = 2
GET_CONFIG.celltype.input = True
GET_CONFIG.celltype.embed = True
GET_CONFIG.celltype.data_dir = '/manitou/pmg/users/xf2217/pretrain_human_bingren_shendure_apr2023/fetal_adult/'
GET_CONFIG.celltype.interpret_dir = '/manitou/pmg/users/xf2217/Interpretation_all_hg38_allembed_v4_natac/'
GET_CONFIG.motif_dir = '/manitou/pmg/users/xf2217/interpret_natac/motif-clustering'

# Motif collection: pickle shipped with the package plus the motif directory.
motif = NrMotifV1.load_from_pickle(
    pkg_resources.resource_filename("atac_rna_data_processing", "data/NrMotifV1.pkl"),
    GET_CONFIG.motif_dir,
)

# The annotation file path is built from the part of data_dir that precedes
# 'fetal_adult'.
_annot_root = GET_CONFIG.celltype.data_dir.split('fetal_adult')[0]
cell_type_annot = pd.read_csv(_annot_root + 'data/cell_type_pretrain_human_bingren_shendure_apr2023.txt')

# Two-way lookup between cell-type ids and cell-type names.
_annot_ids = cell_type_annot['id']
_annot_names = cell_type_annot['celltype']
cell_type_id_to_name = dict(zip(_annot_ids, _annot_names))
cell_type_name_to_id = dict(zip(_annot_names, _annot_ids))

# Names of the cell types that have an entry in interpret_dir; the last path
# component of each entry is used as the cell-type id.
avaliable_celltypes = sorted(
    cell_type_id_to_name[entry.split('/')[-1]]
    for entry in glob(GET_CONFIG.celltype.interpret_dir + '*')
)
#%%
# fill this in...
# Set the default matplotlib figure resolution to 100 DPI.
plt.rcParams['figure.dpi'] = 100



def visualize_AF2(tf_pair, a):
    """Load the AlphaFold2 results for a TF pair and build its plots.

    Args:
        tf_pair: Name of the gene-pair entry under
            ``{args.data}/structures/causal``.
        a: Previous value of the ``af`` state. It is not read; it is replaced
            by a freshly built ``AFPairseg``.

    Returns:
        A 7-tuple: the figures from ``plot_plddt_gene1``, ``plot_plddt_gene2``,
        ``protein1.plot_plddt``, ``protein2.plot_plddt`` and
        ``plot_score_heatmap``, a dropdown update listing the segment pairs,
        and the new ``AFPairseg`` object.

    Raises:
        gr.Error: If no structure directory exists for ``tf_pair``.
    """
    structure_dir = f"{args.data}/structures/causal/{tf_pair}"
    fasta_dir = f"{args.data}/sequences/causal/{tf_pair}"
    if not os.path.exists(structure_dir):
        # Abort the callback here; gr.Error surfaces the message in the UI
        # instead of continuing with a directory that does not exist.
        raise gr.Error("No such gene pair")

    a = AFPairseg(structure_dir, fasta_dir)
    segment_pairs = list(a.pairs_data.keys())
    # `segpair` is the module-level dropdown created in the __main__ block.
    segpair.choices = segment_pairs

    # Each plot helper returns (figure, axes); only the figures are used.
    fig1, _ = a.plot_plddt_gene1()
    fig2, _ = a.plot_plddt_gene2()
    fig3, _ = a.protein1.plot_plddt()
    fig4, _ = a.protein2.plot_plddt()
    fig5, _ = a.plot_score_heatmap()
    plt.tight_layout()

    new_dropdown = update_dropdown(list(segment_pairs), 'Segment pair')
    return fig1, fig2, fig3, fig4, fig5, new_dropdown, a

def view_pdb(seg_pair, a):
    """Render the PDB file of one segment pair as HTML.

    Args:
        seg_pair: Key into ``a.pairs_data``.
        a: Object exposing ``pairs_data``; passed back unchanged.

    Returns:
        A 3-tuple: the HTML produced by ``view_pdb_html``, ``a``, and the
        ``pdb`` attribute of the selected segment pair.
    """
    selected = a.pairs_data[seg_pair]
    pdb_path = selected.pdb
    html = view_pdb_html(pdb_path)
    return html, a, pdb_path



def update_dropdown(x, label):
    """Return a ``gr.Dropdown`` update that sets the choices and the label.

    Args:
        x: New list of choices.
        label: New label for the dropdown.
    """
    dropdown_update = gr.Dropdown.update(label=label, choices=x)
    return dropdown_update

def load_and_plot_celltype(celltype_name, GET_CONFIG, cell):
    """Load a cell type and build its gene-expression figure and table.

    Args:
        celltype_name: Cell-type name; looked up in ``cell_type_name_to_id``.
        GET_CONFIG: Configuration passed to ``GETCellType``.
        cell: Previous value of the ``cell`` state. It is not read; it is
            replaced by the newly loaded cell type.

    Returns:
        A 3-tuple: the figure from ``plotly_gene_exp()``, a table with the
        mean of ``pred``, ``obs`` and ``accessibility`` per ``gene_name``,
        and the loaded ``GETCellType``.
    """
    cell = GETCellType(cell_type_name_to_id[celltype_name], GET_CONFIG)
    cell.celltype_name = celltype_name

    gene_exp_fig = cell.plotly_gene_exp()

    # One row per gene: average the value columns over that gene's rows.
    value_columns = ['pred', 'obs', 'accessibility']
    per_gene = cell.gene_annot.groupby('gene_name')[value_columns]
    gene_exp_table = per_gene.mean().reset_index()

    return gene_exp_fig, gene_exp_table, cell
    
def plot_gene_regions(cell, gene_name, plotly=True):
    """Plot the regions of a gene and pass the cell object through.

    Args:
        cell: Object providing ``plot_gene_regions``.
        gene_name: Gene to plot.
        plotly: Forwarded to ``cell.plot_gene_regions``.

    Returns:
        A 2-tuple: the plot returned by ``cell.plot_gene_regions`` and ``cell``.
    """
    region_fig = cell.plot_gene_regions(gene_name, plotly=plotly)
    return region_fig, cell

def plot_gene_motifs(cell, gene_name, motif, overwrite=False):
    """Plot the motifs of a gene and pass the cell object through.

    Args:
        cell: Object providing ``plot_gene_motifs``.
        gene_name: Gene to plot.
        motif: Motif collection forwarded to ``cell.plot_gene_motifs``.
        overwrite: Forwarded to ``cell.plot_gene_motifs``.

    Returns:
        A 2-tuple: the first element of the ``cell.plot_gene_motifs`` result
        and ``cell``.
    """
    motif_result = cell.plot_gene_motifs(gene_name, motif, overwrite=overwrite)
    # Only the first element of the result is shown in the UI.
    return motif_result[0], cell

def plot_motif_subnet(cell, motif_collection, m, type='neighbors', threshold=0.1):
    """Plot the subnetwork around a motif and pass the cell object through.

    Args:
        cell: Object providing ``plotly_motif_subnet``.
        motif_collection: Motif collection forwarded to the plot call.
        m: Motif forwarded to the plot call.
        type: Subnetwork type forwarded to the plot call.
        threshold: Threshold forwarded to the plot call.

    Returns:
        A 2-tuple: the plot returned by ``cell.plotly_motif_subnet`` and
        ``cell``.
    """
    subnet_fig = cell.plotly_motif_subnet(
        motif_collection,
        m,
        type=type,
        threshold=threshold,
    )
    return subnet_fig, cell

def plot_gene_exp(cell, plotly=True):
    """Plot gene expression for a cell type and pass the cell object through.

    Args:
        cell: Object providing ``plotly_gene_exp``.
        plotly: Forwarded to ``cell.plotly_gene_exp``.

    Returns:
        A 2-tuple: the plot returned by ``cell.plotly_gene_exp`` and ``cell``.
    """
    expression_fig = cell.plotly_gene_exp(plotly=plotly)
    return expression_fig, cell

def plot_motif_corr(cell):
    """Build a clustergram of ``cell.gene_by_motif.corr``.

    Args:
        cell: Object whose ``gene_by_motif.corr`` table is plotted; passed
            back unchanged.

    Returns:
        A 2-tuple: the ``Clustergram`` figure and ``cell``.
    """
    corr = cell.gene_by_motif.corr
    col_names = list(corr.columns.values)
    row_names = list(corr.index)

    fig = Clustergram(
        data=corr,
        column_labels=col_names,
        row_labels=row_names,
        hidden_labels=['row', 'col'],
        link_method='average',
        display_ratio=0.1,
        width=600,
        height=400,
        color_map='rdbu_r',
    )
    return fig, cell

#%%
# fill this in...

# main
if __name__ == '__main__':
    # Build the Gradio UI: three rows of panels, then wire buttons to callbacks.
    with gr.Blocks(theme='sudeepshouche/minimalist') as demo:

        # State holders passed into and out of the callbacks.
        seg_pairs = gr.State([''])
        af = gr.State(None)
        cell = gr.State(None)

        # Row 1: gene expression (left) and per-gene regions/motifs (right).
        with gr.Row() as row:
            # Left column: Plot gene expression and gene regions
            with gr.Column():
                with gr.Row() as row:
                    celltype_name = gr.Dropdown(label='Cell Type', choices=avaliable_celltypes)
                    celltype_btn = gr.Button(value='Load & Plot Gene Expression')
                gene_exp_plot = gr.Plot(label='Gene Expression Pred vs Obs')
                gene_exp_table = gr.DataFrame(label='Gene Expression Table', max_rows=10)

            # Right column: Plot gene motifs
            with gr.Column():
                gene_name_for_region = gr.Textbox(label='Get important regions or motifs for gene:')
                with gr.Row() as row:
                    region_plot_btn = gr.Button(value='Regions')
                    motif_plot_btn = gr.Button(value='Motifs')

                region_plot = gr.Plot(label='Gene Regions')
                motif_plot = gr.Plot(label='Gene Motifs')

        # Row 2: motif correlation heatmap (left) and causal subnetwork (right).
        with gr.Row() as row:
            with gr.Column():
                clustergram_btn = gr.Button(value='Plot Motif Correlation Heatmap')
                clustergram_plot = gr.Plot(label='Motif Correlation')

            # Right column: Motif subnet plot
            with gr.Column():
                with gr.Row() as row:
                    motif_for_subnet = gr.Dropdown(label='Motif Causal Subnetwork', choices=motif.cluster_names)
                    # `value` (not `default`) sets the initial selection, so the
                    # callback receives 'neighbors' unless the user changes it.
                    subnet_type = gr.Dropdown(label='Type', choices=['neighbors', 'parents', 'children'], value='neighbors')
                    # slider for threshold 0.01-0.25
                    subnet_threshold = gr.Slider(label='Threshold', minimum=0.01, maximum=0.25, step=0.01, value=0.1)
                subnet_btn = gr.Button(value='Plot Motif Causal Subnetwork')
                subnet_plot = gr.Plot(label='Motif Causal Subnetwork')

        # Row 3: AlphaFold2 TF-pair plots (left) and PDB viewer (right).
        with gr.Row() as row:
            with gr.Column():
                with gr.Row() as row:
                    tf_pairs = gr.Dropdown(label='TF pair', choices=gene_pairs)
                    tf_pairs_btn = gr.Button(value='Load & Plot')
                interact_plddt1 = gr.Plot(label='Interact pLDDT 1')
                interact_plddt2 = gr.Plot(label='Interact pLDDT 2')
                protein1_plddt = gr.Plot(label='Protein 1 pLDDT')
                protein2_plddt = gr.Plot(label='Protein 2 pLDDT')

                heatmap = gr.Plot(label='Heatmap')

            with gr.Column():
                with gr.Row() as row:
                    segpair = gr.Dropdown(label='Seg pair', choices=seg_pairs.value)
                    segpair_btn = gr.Button(value='Get PDB')
                pdb_html = gr.HTML(label="PDB HTML")
                pdb_file = gr.File(label='Download PDB')

        # Callback wiring. Each callback returns the state object it received
        # (or replaced) so the matching gr.State is refreshed.
        tf_pairs_btn.click(
            visualize_AF2,
            inputs=[tf_pairs, af],
            outputs=[interact_plddt1, interact_plddt2, protein1_plddt, protein2_plddt, heatmap, segpair, af],
        )
        segpair_btn.click(view_pdb, inputs=[segpair, af], outputs=[pdb_html, af, pdb_file])
        celltype_btn.click(
            load_and_plot_celltype,
            inputs=[celltype_name, gr.State(GET_CONFIG), cell],
            outputs=[gene_exp_plot, gene_exp_table, cell],
        )
        region_plot_btn.click(plot_gene_regions, inputs=[cell, gene_name_for_region], outputs=[region_plot, cell])
        motif_plot_btn.click(
            plot_gene_motifs,
            inputs=[cell, gene_name_for_region, gr.State(motif)],
            outputs=[motif_plot, cell],
        )
        clustergram_btn.click(plot_motif_corr, inputs=[cell], outputs=[clustergram_plot, cell])
        subnet_btn.click(
            plot_motif_subnet,
            inputs=[cell, gr.State(motif), motif_for_subnet, subnet_type, subnet_threshold],
            outputs=[subnet_plot, cell],
        )

    demo.launch(share=args.share, server_port=args.port)


# %%