# SHEPHERD / app.py
# Repository page header captured with the file:
#   emilyalsentzer's picture
#   Update app.py
#   5b2b768 verified
import gradio as gr
import pandas as pd
from pathlib import Path
import ast
# Score and attention tables, read once at import time from CSV files in the
# current working directory. Each name lines up with the file at the same
# position in the tuple below; the files are read in the order listed.
(
    gene_scores_df,
    exomiser_gene_scores_df,
    patient_scores_df,
    dx_scores_df,
    plm_attn_df,
    dx_attn_df,
    gene_attn_df,
    exomiser_gene_attn_df,
) = map(
    pd.read_csv,
    (
        'gene_discovery_scores.csv',
        'exomiser_gene_discovery_scores.csv',
        'patients_like_me_scores.csv',
        'dx_characterization_scores.csv',
        'patients_like_me_scores_attn.csv',
        'dx_characterization_scores_attn.csv',
        'gene_discovery_scores_attn.csv',
        'exomiser_gene_discovery_scores_attn.csv',
    ),
)
# Hand-written display text keyed by patient ID. get_patient() checks these
# maps first and only falls back to the labelled rows of the score tables for
# patients that have no entry here. Values are rendered as markdown, so
# *...* shows in italics.
diseases_map = {
    'UDN-P1': 'POLR3-related leukodystrophy',
    'UDN-P2': 'Novel PRKAR1B-Related Neurodevelopmental Disorder',
    'UDN-P3': 'Coffin-Lowry syndrome',
    'UDN-P4': 'autosomal recessive spastic paraplegia type 76',
    'UDN-P5': 'atypical presentation of familial cold autoinflammatory syndrome',
    'UDN-P6': '*GATAD2B*-associated syndrome',
    'UDN-P7': 'AR limb-girdle muscular atrophy type 2D',
    'UDN-P8': '*ATP5PO*-related Leigh syndrome',
    'UDN-P9': 'Spondyloepimetaphyseal dysplasia, Isidor-Toutain type',
}

# Causal gene symbol(s) per patient. The UDN-P8 symbol ends in the letter O
# (ATP5PO), matching the spelling used in that patient's disease name above.
genes_map = {
    'UDN-P3': 'RPS6KA3',
    'UDN-P4': 'CAPN1',
    'UDN-P5': 'NLRP12, RAPGEFL1',
    'UDN-P6': 'GATAD2B',
    'UDN-P7': 'SGCA',
    'UDN-P8': 'ATP5PO',
    'UDN-P9': 'RPL13',
}
def get_patient(patient_id, attn_df):
    '''
    Build the markdown summary (patient ID, causal gene, disease, phenotypes)
    for one patient.

    patient_id: ID compared against the 'patient_id' column of each table.
    attn_df: table with 'patient_id' and 'phenotypes' columns; the rows for
        this patient supply the phenotype list.

    Returns a markdown string whose lines end in <br>.
    '''
    def positives(table, flag_column, value_column):
        # Comma-joined values from this patient's rows whose flag equals 1.
        own_rows = table.loc[table['patient_id'] == patient_id]
        flagged = own_rows.loc[own_rows[flag_column] == 1, value_column]
        return ', '.join(flagged.tolist())

    # Entries in the module-level maps take precedence; the score tables are
    # only consulted for patients that have no hand-written entry.
    gene = (
        genes_map[patient_id]
        if patient_id in genes_map
        else positives(gene_scores_df, 'correct_gene_label', 'genes')
    )
    disease = (
        diseases_map[patient_id]
        if patient_id in diseases_map
        else positives(dx_scores_df, 'correct_label', 'diseases')
    )
    phenotypes = ', '.join(
        attn_df.loc[attn_df['patient_id'] == patient_id, 'phenotypes'].tolist()
    )

    # The empty first and last items reproduce the leading and trailing
    # newlines of the rendered block.
    summary_lines = [
        '',
        f'**Selected Patient:** {patient_id}<br>',
        f'**Causal Gene:** *{gene}*<br>',
        f'**Disease:** {disease}<br>',
        f'**Phenotypes:** {phenotypes}<br><br>',
        '',
    ]
    return '\n'.join(summary_lines)
def read_file(filename, encoding='utf-8'):
    '''
    Return the full contents of a text file as one string.

    filename: path of the file to read.
    encoding: text encoding used to decode the file. Defaults to UTF-8 so the
        result does not depend on the platform's locale encoding.

    Raises OSError if the file cannot be opened and UnicodeDecodeError if its
    bytes are not valid in the given encoding.
    '''
    with open(filename, 'r', encoding=encoding) as handle:
        return handle.read()
def causal_gene_discovery(patient_id, prioritization_type):
    '''
    Assemble everything shown on the "Causal Gene Discovery" tab for a patient.

    patient_id: ID compared against the 'patient_id' column of the tables.
    prioritization_type: 'Variant Filtered' selects the exomiser tables; any
        other value selects the default gene tables.

    Returns a 4-tuple:
        (patient summary markdown, ranked gene table, phenotype attention
        table, HTML iframe for the knowledge-graph neighbourhood page).
    '''
    variant_filtered = prioritization_type == 'Variant Filtered'

    def emphasize(text):
        # Green + bold markdown used to flag the causal gene's row.
        return f'<span style="color:green">**{text}**</span>'

    # ---- gene ranking -----------------------------------------------------
    score_source = exomiser_gene_scores_df if variant_filtered else gene_scores_df
    ranked = score_source.loc[score_source['patient_id'] == patient_id]
    ranked = ranked.sort_values("similarities", ascending=False)
    # Scores become strings so the causal row can carry markdown markup.
    ranked['similarities'] = ranked['similarities'].round(3).astype(str)
    # Each gene symbol links to its GeneCards page.
    ranked['genes'] = ranked['genes'].apply(
        lambda symbol: f'<u>[{symbol}](https://www.genecards.org/cgi-bin/carddisp.pl?gene={symbol})</u>'
    )
    is_causal = ranked['correct_gene_label'] == 1
    for column in ('similarities', 'genes'):
        ranked.loc[is_causal, column] = ranked.loc[is_causal, column].apply(emphasize)
    ranked = ranked.drop(columns=['patient_id', 'correct_gene_label'])
    ranked = ranked.rename(columns={'similarities': 'SHEPHERD Score', 'genes': 'Candidate Genes'})

    # ---- phenotype attention ---------------------------------------------
    attn_source = exomiser_gene_attn_df if variant_filtered else gene_attn_df
    attention = attn_source.loc[attn_source['patient_id'] == patient_id]
    attention = attention.sort_values("attention", ascending=False)
    attention['attention'] = attention['attention'].round(4)
    attention = attention.drop(columns=['patient_id', 'degrees'])

    # ---- knowledge-graph neighbourhood -----------------------------------
    page_url = f'https://michellemli.github.io/test_html/{patient_id}.html'
    kg_html = f'''<iframe id="igraph" scrolling="no" style="border:none; width: 100%; height: 600px" seamless="seamless" src="{page_url}"></iframe>'''

    # The summary always reads phenotypes from the default gene attention
    # table, regardless of the selected gene-list type.
    summary = get_patient(patient_id, gene_attn_df)
    return summary, ranked, attention, kg_html
def patients_like_me(patient_id, k=10):
    '''
    Assemble everything shown on the "Patients Like Me" tab for a patient.

    patient_id: ID compared against the 'patient_id' column of the tables.
    k: number of most-similar candidate patients to keep.

    Returns a 3-tuple:
        (patient summary markdown, table of the k most similar patients,
        phenotype attention table).
    '''
    def emphasize(text):
        # Green + bold markdown for candidates flagged with correct_label == 1.
        return f'<span style="color:green">**{text}**</span>'

    similar = patient_scores_df.loc[patient_scores_df['patient_id'] == patient_id]
    similar = similar.sort_values("similarities", ascending=False)

    # A disease link is "<u>[name]" + "(orphanet url)</u>"; the URL half is
    # built from the disease ID column.
    link_target = similar['disease_ids'].apply(
        lambda disease_id: f'(https://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=en&Expert={disease_id})</u>'
    )
    link_text = similar['diseases'].apply(lambda disease: f'<u>[{disease}]')
    similar['diseases'] = link_text + link_target
    similar['genes'] = similar['genes'].apply(
        lambda symbol: f'<u>[{symbol}](https://www.genecards.org/cgi-bin/carddisp.pl?gene={symbol})</u>'
    )

    is_match = similar['correct_label'] == 1
    for column in ('candidate_patients', 'genes', 'diseases'):
        similar.loc[is_match, column] = similar.loc[is_match, column].apply(emphasize)

    display_names = {
        'candidate_patients': 'Candidate Patient',
        'genes': 'Candidate Patient\'s Gene',
        'diseases': 'Candidate Patient\'s Disease',
    }
    similar = similar.drop(columns=['patient_id', 'similarities', 'correct_label', 'disease_ids'])
    similar = similar.rename(columns=display_names).head(k)

    # Phenotype attention for this patient, strongest first.
    attention = plm_attn_df.loc[plm_attn_df['patient_id'] == patient_id]
    attention = attention.sort_values("attention", ascending=False)
    attention['attention'] = attention['attention'].round(4)
    attention = attention.drop(columns=['patient_id', 'degrees'])

    summary = get_patient(patient_id, plm_attn_df)
    return summary, similar, attention
def disease_characterization(patient_id, k=10):
    '''
    Assemble everything shown on the "Disease Characterization" tab.

    patient_id: ID compared against the 'patient_id' column of the tables.
    k: number of top-ranked diseases to keep.

    Returns a 3-tuple:
        (patient summary markdown, table with a single 'Disease' column of
        markdown links for the k highest-scoring diseases, phenotype
        attention table).

    Raises whatever ast.literal_eval raises if a 'disease_ids' value is
    neither one of the known disease names nor a Python literal.
    '''
    # 'disease_ids' values that contain one of these names are swapped for a
    # numeric placeholder ID before parsing.
    name_placeholders = {
        'Coxa vara': '2812',
        'Multiple epiphyseal dysplasia': '2654',
    }
    # IDs whose link is replaced wholesale. 33657 could not be mapped to
    # Orphanet; the other two are the placeholders assigned above. Applied in
    # this order, so a later entry wins if several ever match one row.
    link_overrides = {
        '33657': '<u>[leukodystrophy, hypomyelinating, 20](https://www.omim.org/entry/619071)</u>',
        '2654': '<u>[Multiple epiphyseal dysplasia](https://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=EN&Expert=251)</u>',
        '2812': '<u>[Coxa vara](https://omim.org/entry/122750)</u>',
    }

    def resolve_id(raw):
        # Reduce one raw 'disease_ids' value to the single ID used in the link.
        raw = str(raw)
        for disease_name, placeholder in name_placeholders.items():
            if disease_name in raw:
                return placeholder
        parsed = ast.literal_eval(raw)
        if isinstance(parsed, (list, tuple)):
            # Several IDs for one disease: link to the first.
            parsed = parsed[0]
        return str(parsed)

    scores_df = dx_scores_df.loc[dx_scores_df['patient_id'] == patient_id]
    # .copy() so the column assignments below act on an independent frame.
    scores_df = scores_df.sort_values("similarities", ascending=False).head(k).copy()

    ids = scores_df['disease_ids'].apply(resolve_id).astype(str)

    # Default: link every disease name to its Orphanet page.
    link_text = scores_df['diseases'].apply(lambda disease: f'<u>[{disease}]')
    link_target = ids.apply(
        lambda disease_id: f'(https://www.orpha.net/consor/cgi-bin/OC_Exp.php?lng=en&Expert={disease_id})</u>'
    )
    scores_df['diseases'] = link_text + link_target

    # Match the ID as a whole number (no digit directly before or after) so
    # that, e.g., an ID of 12654 is not mistaken for 2654.
    for special_id, link in link_overrides.items():
        is_special = ids.str.contains(rf'(?<!\d){special_id}(?!\d)', regex=True)
        scores_df.loc[is_special, 'diseases'] = link

    scores_df = scores_df.drop(columns=['patient_id', 'similarities', 'correct_label', 'disease_ids'])
    scores_df = scores_df.rename(columns={'diseases': 'Disease'})

    # Phenotype attention for this patient, strongest first.
    attn_df = dx_attn_df.loc[dx_attn_df['patient_id'] == patient_id]
    attn_df = attn_df.sort_values("attention", ascending=False)
    attn_df['attention'] = attn_df['attention'].round(4)
    attn_df = attn_df.drop(columns=['patient_id', 'degrees'])

    patient = get_patient(patient_id, dx_attn_df)
    return patient, scores_df, attn_df
def get_umap(umap_type):
    '''
    Return an <embed> HTML snippet that displays one of the hosted UMAP pages.

    umap_type: 'disease' or 'patient'.

    Raises NotImplementedError for any other value.
    '''
    hosted_pages = (
        ('disease', 'https://michellemli.github.io/test_html/shepherd_disease_characterization_umap.html'),
        ('patient', 'https://michellemli.github.io/test_html/shepherd_patient_umap.html'),
    )
    for kind, html_file in hosted_pages:
        if umap_type == kind:
            return f'''<embed style="border: none;" src="{html_file}" dpi="300" width="100%" height="750px" />'''
    raise NotImplementedError
# Gradio UI: three tabs (gene discovery, similar patients, disease
# characterization), each with a patient dropdown whose change event fills
# in that tab's outputs.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation -- confirm the container structure against
# the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown('<center><h1>AI-assisted Rare Disease Diagnosis with SHEPHERD</h1></center>')

    with gr.Tabs():
        with gr.TabItem("Causal Gene Discovery"):
            with gr.Column():
                gr.Markdown('<center><h2>Select a patient to view SHEPHERD\'s predictions</h2></center>')
                gene_dropdown = gr.Dropdown(choices=['UDN-P1', 'UDN-P2'], label='Rare Disease Patients', type='value')
                gene_radio = gr.Radio(choices=['Expert Curated', 'Variant Filtered'], value='Expert Curated', label='Type of Gene List')
                patient_info = gr.Markdown()

                with gr.Accordion(label='SHEPHERD\'s Ranking of Patient\'s Candidate Genes', open=True, elem_id='gene_accordion'):
                    gr.Markdown('Below are SHEPHERD\'s ranking of either all Expert Curated candidate genes or the top 10 Variant Filtered candidate genes. The patient\'s causal gene (i.e. gene harboring a variant that explains the patient\'s symptoms) is colored in green.')
                    # datatype='markdown' so the links and green highlighting render.
                    gene_dataframe = gr.Dataframe(elem_id="gene_df", datatype='markdown', headers=['Candidate Genes', 'SHEPHERD Score'], overflow_row_behaviour='paginate')

                with gr.Accordion(label='SHEPHERD\'s Attention to Patient\'s Phenotypes', open=False, elem_id='gene_attn_accordion'):
                    gene_attn_dataframe = gr.Dataframe(elem_id="gene_attn_df", headers=['Phenotypes', 'Attention'], overflow_row_behaviour='paginate')

                with gr.Accordion(label='Visualization of Patient\'s Neighborhood in the Knowledge Graph', open=False, elem_id='kg_neigh_accordion'):
                    kg_neighborhood_image = gr.HTML(elem_id='kg_patient_neighborhood')

        with gr.TabItem("Patients Like Me"):
            gr.HTML(get_umap('patient'))
            gr.Markdown('<center><h2>Select a patient to view SHEPHERD\'s predictions</h2></center>')
            patient_dropdown = gr.Dropdown(choices=['UDN-P3', 'UDN-P4', 'UDN-P5', 'UDN-P6'], label='Rare Disease Patients', type='value')
            p_patient_info = gr.Markdown()

            with gr.Accordion(label='Top 10 Most Similar Patients according to SHEPHERD', open=True, elem_id='pt_accordion'):
                patient_dataframe = gr.Dataframe(max_rows=10, datatype='markdown', show_label=False, elem_id="pat_df", headers=['Candidate Patient', 'Candidate Patient\'s Gene', 'Candidate Patient\'s Disease'])

            with gr.Accordion(label='SHEPHERD\'s Attention to Patient\'s Phenotypes', open=False, elem_id='pt_attn_accordion'):
                pt_attn_dataframe = gr.Dataframe(elem_id="pt_attn_df", headers=['Phenotypes', 'Attention'], overflow_row_behaviour='paginate')

        with gr.TabItem("Disease Characterization"):
            gr.HTML(get_umap('disease'))
            gr.Markdown('<center><h2>Select a patient to view SHEPHERD\'s predictions</h2></center>')
            dx_dropdown = gr.Dropdown(choices=['UDN-P7', 'UDN-P8', 'UDN-P9', 'UDN-P2'], label='Rare Disease Patients', type='value')
            dx_patient_info = gr.Markdown()

            # This accordion has its own elem_id; it previously reused
            # 'pt_accordion' from the "Patients Like Me" tab, giving two
            # elements the same id.
            with gr.Accordion(label='Top 10 Most Similar Diseases according to SHEPHERD', open=True, elem_id='dx_accordion'):
                dx_dataframe = gr.Dataframe(max_rows=10, datatype='markdown', show_label=False, elem_id="dx_df", headers=['Diseases'])

            with gr.Accordion(label='SHEPHERD\'s Attention to Patient\'s Phenotypes', open=False, elem_id='dx_attn_accordion'):
                dx_attn_dataframe = gr.Dataframe(elem_id="dx_attn_df", headers=['Phenotypes', 'Attention'], overflow_row_behaviour='paginate')

    # Event wiring. The gene tab refreshes when either the patient or the
    # gene-list type changes; the other two tabs refresh on patient change.
    gene_outputs = [patient_info, gene_dataframe, gene_attn_dataframe, kg_neighborhood_image]
    gene_dropdown.change(causal_gene_discovery, inputs=[gene_dropdown, gene_radio], outputs=gene_outputs)
    gene_radio.change(causal_gene_discovery, inputs=[gene_dropdown, gene_radio], outputs=gene_outputs)
    patient_dropdown.change(patients_like_me, inputs=patient_dropdown, outputs=[p_patient_info, patient_dataframe, pt_attn_dataframe])
    dx_dropdown.change(disease_characterization, inputs=dx_dropdown, outputs=[dx_patient_info, dx_dataframe, dx_attn_dataframe])

demo.launch()  # server_port=50018, share=True