|
import collections |
|
import os |
|
from typing import Dict, List |
|
|
|
import gradio as gr |
|
|
|
from index_list import read_index_list |
|
from protein_viz import get_pdb_title, render_html |
|
from search_engine import MilvusParams, ProteinSearchEngine |
|
|
|
# Limits and separators shared by the search handlers and the UI below.
max_results = 1000  # upper bound used by limit_n_results and the results label
max_seq_length = 50  # longest input accepted by _validate_sequence_length
choice_sep = " | "  # joins/splits the fields packed into one dropdown choice

# Model repository identifier handed to the search engine.
model_repo = "ronig/protein_biencoder"

# Entries offered by the "Index" dropdown.
available_indexes = read_index_list()

# Search backend; the Milvus token is read from the MILVUS_TOKEN environment
# variable (None when the variable is unset).
engine = ProteinSearchEngine(
    milvus_params=MilvusParams(
        uri="https://in03-ddab8e9a5a09fcc.api.gcp-us-west1.zillizcloud.com",
        token=os.getenv("MILVUS_TOKEN"),
        db_name="Protein",
        collection_name="Peptriever",
    ),
    model_repo=model_repo,
)
|
|
|
|
|
def search_and_display(seq, max_res, index_selection):
    """Search for ``seq`` and build the outputs for the results widgets.

    Args:
        seq: Query sequence; rejected when longer than ``max_seq_length``.
        max_res: Requested number of results; clamped by ``limit_n_results``
            and converted to ``int``.
        index_selection: Selected index name. ``"All Species"`` means no
            organism filter is passed to the engine.

    Returns:
        A ``(formatted_results, dropdown_update)`` tuple: the label-to-score
        mapping from ``format_search_results`` and the update produced by
        ``update_dropdown_menu``.

    Raises:
        gr.Error: via ``_validate_sequence_length`` for over-long input.
    """
    # Number of raw hits requested from the engine before grouping.
    raw_hit_count = 1024

    _validate_sequence_length(seq)
    result_cap = int(limit_n_results(max_res))

    organism_filter = None if index_selection == "All Species" else index_selection
    raw_hits = engine.search_by_sequence(
        seq, n=raw_hit_count, organism=organism_filter
    )

    grouped_hits = aggregate_search_results(raw_hits, result_cap)
    return format_search_results(grouped_hits), update_dropdown_menu(grouped_hits)
|
|
|
|
|
def _validate_sequence_length(seq):
    """Reject input whose length exceeds ``max_seq_length``.

    Raises:
        gr.Error: if ``len(seq)`` is greater than ``max_seq_length``.
    """
    if len(seq) <= max_seq_length:
        return
    raise gr.Error("Only peptide input is currently supported")
|
|
|
|
|
def limit_n_results(n):
    """Clamp ``n`` so it is at most ``max_results`` and at least 1."""
    upper_bounded = min(n, max_results)
    return max(upper_bounded, 1)
|
|
|
|
|
def aggregate_search_results(raw_results: List[dict], max_res: int) -> Dict[str, dict]:
    """Group raw search hits by their ``uniprot_id``.

    Hits are consumed in the order given. Each hit is trimmed to a fixed set
    of fields and appended to the list for its UniProt id; hits whose
    ``uniprot_id`` is ``None`` are not stored. Iteration stops as soon as the
    number of distinct ids reaches ``max_res``.

    Args:
        raw_results: Hit dicts; each must contain every field listed in
            ``wanted_fields`` below.
        max_res: Maximum number of distinct UniProt ids to collect.

    Returns:
        A plain dict mapping each UniProt id to the list of trimmed hits
        collected for it.
    """
    wanted_fields = ["pdb_name", "chain_id", "score", "organism", "uniprot_id", "genes"]
    grouped = collections.defaultdict(list)

    for hit in raw_results:
        trimmed = select_keys(hit, keys=wanted_fields)
        accession = hit["uniprot_id"]
        if accession is not None:
            grouped[accession].append(trimmed)
        # NOTE(review): SOURCE indentation was lost; this check is kept at loop
        # level. Nesting it under the branch above only differs for
        # max_res <= 0 - confirm against the original layout.
        if len(grouped) >= max_res:
            break

    return dict(grouped)
|
|
|
|
|
def select_keys(d: dict, keys: List[str]):
    """Return a new dict containing only the entries of ``d`` named in ``keys``.

    The result follows the order of ``keys``.

    Raises:
        KeyError: if any requested key is missing from ``d``.
    """
    subset = {}
    for name in keys:
        subset[name] = d[name]
    return subset
|
|
|
|
|
def format_search_results(agg_search_results):
    """Turn grouped hits into a ``{display label: score}`` mapping.

    Only the first entry of each group is used: its organism and gene names
    go into the label and its score becomes the value.

    Args:
        agg_search_results: Mapping of UniProt id to a non-empty list of
            entries with ``organism``, ``score`` and ``genes`` keys.

    Returns:
        Dict mapping one label per UniProt id to that group's first score.
    """
    labelled_scores = {}
    for uniprot_id, hits in agg_search_results.items():
        first_hit = hits[0]
        organism, score, genes = (
            first_hit[field] for field in ("organism", "score", "genes")
        )
        label = f"Uniprot ID: {uniprot_id} | Organism: {organism} | Gene Names: {genes}"
        labelled_scores[label] = score
    return labelled_scores
|
|
|
|
|
def update_dropdown_menu(agg_search_res):
    """Build the update for the result-selection dropdown.

    Every entry of every group becomes one choice made of its UniProt id,
    PDB name, chain id and gene names (empty string when falsy), joined with
    ``choice_sep``. With at least one choice the dropdown is shown and the
    first choice is selected; with none it is hidden and its value cleared.

    Args:
        agg_search_res: Mapping of UniProt id to a list of entries with
            ``pdb_name``, ``chain_id`` and ``genes`` keys.

    Returns:
        The object returned by ``gr.update`` for the dropdown.
    """
    options = []
    for uniprot_id, hits in agg_search_res.items():
        options.extend(
            choice_sep.join(
                [uniprot_id, hit["pdb_name"], hit["chain_id"], hit["genes"] or ""]
            )
            for hit in hits
        )

    has_options = bool(options)
    # NOTE(review): the Dropdown class id is passed as the first positional
    # argument exactly as before; confirm how gr.update interprets it.
    return gr.update(
        gr.Dropdown.get_component_class_id(),
        choices=options,
        interactive=True,
        value=options[0] if has_options else None,
        visible=has_options,
    )
|
|
|
|
|
def parse_pdb_search_result(raw_result):
    """Build a ``(label, score)`` pair for a single raw hit.

    The label is ``"PDB: <pdb_name>.<chain_id>"``; gene names and organism
    are appended only when ``genes`` is not ``None``.

    Args:
        raw_result: Dict with ``pdb_name``, ``chain_id``, ``score``,
            ``genes`` and ``organism`` keys.

    Returns:
        Tuple of the label string and the hit's score.
    """
    pdb_name, chain_id, score, genes, organism = (
        raw_result[field]
        for field in ("pdb_name", "chain_id", "score", "genes", "organism")
    )
    label = f"PDB: {pdb_name}.{chain_id}"
    if genes is None:
        return label, score
    return label + f" | Genes: {genes} | Organism: {organism}", score
|
|
|
|
|
def switch_viz(new_choice):
    """Produce the visualization outputs for the selected dropdown choice.

    Args:
        new_choice: A choice string whose fields are joined with
            ``choice_sep`` (the second and third fields are the PDB id and
            chain), or ``None`` when nothing is selected.

    Returns:
        A ``(html, title_update, description_update)`` tuple. For ``None``
        the HTML is empty and both Markdown updates hide their component;
        otherwise the HTML comes from ``render_html`` and the description
        shows the title returned by ``get_pdb_title``.
    """
    if new_choice is None:
        hidden_title = gr.update(gr.Markdown.get_component_class_id(), visible=False)
        hidden_description = gr.update(
            gr.Markdown.get_component_class_id(), value=None, visible=False
        )
        return "", hidden_title, hidden_description

    # Choice layout: <uniprot id> | <pdb id> | <chain> | <genes>
    pdb_id, chain = new_choice.split(choice_sep)[1:3]

    shown_title = gr.update(gr.Markdown.get_component_class_id(), visible=True)
    description_text = f"**PDB Title**: {get_pdb_title(pdb_id)}"
    shown_description = gr.update(
        gr.Markdown.get_component_class_id(), value=description_text, visible=True
    )
    viewer_html = render_html(pdb_id=pdb_id, chain=chain)
    return viewer_html, shown_title, shown_description
|
|
|
|
|
# Gradio UI: a search form, a results label, and a visualization area, plus
# the event wiring that connects them to search_and_display and switch_viz.
# NOTE(review): SOURCE had lost its indentation; the container nesting below
# is a reconstruction - confirm which Column/Row each component belongs to.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    # Search form inputs.
                    seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence")
                    n_results = gr.Number(10, label="N Results")
                    index_selector = gr.Dropdown(
                        choices=available_indexes,
                        value="All Species",
                        multiselect=False,
                        visible=True,
                        label="Index",
                    )
                    search_button = gr.Button("Search", variant="primary")
                # Receives the label-to-score mapping from search_and_display.
                search_results = gr.Label(
                    num_top_classes=max_results, label="Search Results", scale=2
                )
            # Visualization widgets start hidden; switch_viz and
            # update_dropdown_menu toggle their visibility.
            viz_header = gr.Markdown("## Visualization", visible=False)
            results_selector = gr.Dropdown(
                choices=[],
                multiselect=False,
                visible=False,
                label="Visualized Search Result",
            )
            viz_body = gr.Markdown("", visible=False)
            protein_viz = gr.HTML(
                value=render_html(pdb_id=None, chain=None),
                label="Protein Visualization",
            )
            # Example sequences that fill the sequence textbox.
            gr.Examples(
                ["APTMPPPLPP", "KFLIYQMECSTMIFGL", "PHFAMPPIHEDHLE", "AEERIISLD"],
                inputs=[seq_input],
            )
    # Clicking "Search" runs the search and refreshes the results label and
    # the result-selection dropdown.
    search_button.click(
        search_and_display,
        inputs=[seq_input, n_results, index_selector],
        outputs=[search_results, results_selector],
    )
    # Changing the selected result re-renders the viewer, header and body.
    results_selector.change(
        switch_viz, inputs=results_selector, outputs=[protein_viz, viz_header, viz_body]
    )
|
|
|
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|