Spaces:
Running
Running
| import collections | |
| import os | |
| from typing import Dict, List | |
| import gradio as gr | |
| from index_list import read_index_list | |
| from protein_viz import get_pdb_title, render_html | |
| from search_engine import MilvusParams, ProteinSearchEngine | |
# --- Tunable limits and separators used by the handlers below ---------------

# Upper bound applied to the user-requested number of results.
max_results = 1000
# Separator used to build (and later split) the dropdown choice strings.
choice_sep = " | "
# Longest query sequence accepted by the search handler.
max_seq_length = 50

# --- Search backend ---------------------------------------------------------

# Model repository identifier handed to the search engine.
model_repo = "ronig/protein_biencoder"

# Index names offered in the "Index" dropdown of the UI.
available_indexes = read_index_list()

# The Milvus token is read from the environment; ``os.environ.get`` yields
# ``None`` when MILVUS_TOKEN is not set.
engine = ProteinSearchEngine(
    milvus_params=MilvusParams(
        uri="https://in03-ddab8e9a5a09fcc.api.gcp-us-west1.zillizcloud.com",
        token=os.environ.get("MILVUS_TOKEN"),
        db_name="Protein",
        collection_name="Peptriever",
    ),
    model_repo=model_repo,
)
def search_and_display(seq, max_res, index_selection):
    """Run a sequence search and build both outputs of the search button.

    Args:
        seq: Query sequence; rejected when longer than ``max_seq_length``.
        max_res: Requested number of results; clamped by ``limit_n_results``
            and converted to ``int``.
        index_selection: Organism to restrict the search to. The value
            ``"All Species"`` is translated to ``None`` (no organism given).

    Returns:
        A ``(formatted_results, dropdown_update)`` tuple: the mapping built
        by ``format_search_results`` and the update built by
        ``update_dropdown_menu``.

    Raises:
        gr.Error: If the sequence fails ``_validate_sequence_length``.
    """
    # Number of raw hits requested from the engine, independent of how many
    # aggregated results are finally shown.
    raw_hit_count = 1024

    _validate_sequence_length(seq)
    result_cap = int(limit_n_results(max_res))

    organism = None if index_selection == "All Species" else index_selection
    hits = engine.search_by_sequence(seq, n=raw_hit_count, organism=organism)

    grouped = aggregate_search_results(hits, result_cap)
    label_data = format_search_results(grouped)
    dropdown_update = update_dropdown_menu(grouped)
    return label_data, dropdown_update
def _validate_sequence_length(seq):
    """Reject query sequences longer than ``max_seq_length``.

    Raises:
        gr.Error: If ``len(seq)`` exceeds ``max_seq_length``.
    """
    if len(seq) <= max_seq_length:
        return
    raise gr.Error("Only peptide input is currently supported")
def limit_n_results(n):
    """Clamp ``n`` to at most ``max_results`` and at least 1."""
    capped = min(n, max_results)
    return max(capped, 1)
def aggregate_search_results(raw_results: List[dict], max_res: int) -> Dict[str, dict]:
    """Group raw search hits by their ``uniprot_id``.

    Hits are consumed in the order given. Each hit is reduced to a fixed set
    of keys and appended to the list for its ``uniprot_id``; hits whose
    ``uniprot_id`` is ``None`` are not stored. Consumption stops as soon as
    the number of distinct ids reaches ``max_res``.

    Args:
        raw_results: Hit dictionaries; each must contain all selected keys.
        max_res: Number of distinct ``uniprot_id`` groups at which to stop.

    Returns:
        A plain ``dict`` mapping each ``uniprot_id`` to its list of reduced
        hit dictionaries.
    """
    wanted_keys = ["pdb_name", "chain_id", "score", "organism", "uniprot_id", "genes"]
    grouped = collections.defaultdict(list)

    for hit in raw_results:
        reduced = select_keys(hit, keys=wanted_keys)
        accession = hit["uniprot_id"]

        if accession is not None:
            grouped[accession].append(reduced)
        # Stop once enough distinct ids have been collected.
        if len(grouped) >= max_res:
            break

    return dict(grouped)
def select_keys(d: dict, keys: List[str]):
    """Return a new dict holding only the entries of ``d`` named in ``keys``.

    Raises:
        KeyError: If any requested key is missing from ``d``.
    """
    picked = {}
    for name in keys:
        picked[name] = d[name]
    return picked
def format_search_results(agg_search_results):
    """Build the label-to-score mapping shown in the search results widget.

    For every ``uniprot_id`` group, the first entry supplies the organism,
    gene names and score. The label has the form
    ``"Uniprot ID: <id> | Organism: <organism> | Gene Names: <genes>"``.

    The gene-names segment is left out when the entry has no gene names
    (``None`` or empty), instead of rendering the literal text ``None``.
    This matches the other formatters in this file, which also treat
    ``genes`` as optional.

    Args:
        agg_search_results: Mapping of ``uniprot_id`` to a non-empty list of
            entry dicts with ``organism``, ``score`` and ``genes`` keys.

    Returns:
        Dict mapping each display label to the first entry's score.
    """
    formatted_search_results = {}
    for uniprot_id, entries in agg_search_results.items():
        # Only the first entry of each group contributes to the label.
        entry = entries[0]
        organism = entry["organism"]
        score = entry["score"]
        genes = entry["genes"]

        key = f"Uniprot ID: {uniprot_id} | Organism: {organism}"
        if genes:
            key += f" | Gene Names: {genes}"
        formatted_search_results[key] = score
    return formatted_search_results
def update_dropdown_menu(agg_search_res):
    """Build the update for the "Visualized Search Result" dropdown.

    One choice string is produced per entry of every ``uniprot_id`` group,
    by joining the id, ``pdb_name``, ``chain_id`` and gene names (empty
    string when falsy) with ``choice_sep``.

    Args:
        agg_search_res: Mapping of ``uniprot_id`` to a list of entry dicts.

    Returns:
        The ``gr.update`` result. With at least one choice, the dropdown is
        made visible and the first choice is selected; with none, it is
        hidden and its value is cleared.
    """
    choices = [
        choice_sep.join(
            [
                uniprot_id,
                entry["pdb_name"],
                entry["chain_id"],
                entry["genes"] or "",
            ]
        )
        for uniprot_id, entries in agg_search_res.items()
        for entry in entries
    ]

    has_choices = bool(choices)
    selected = choices[0] if has_choices else None
    return gr.update(
        gr.Dropdown.get_component_class_id(),
        choices=choices,
        interactive=True,
        value=selected,
        visible=has_choices,
    )
def parse_pdb_search_result(raw_result):
    """Turn one raw hit into a ``(label, score)`` pair.

    The label is ``"PDB: <pdb_name>.<chain_id>"``; when the hit's ``genes``
    value is not ``None``, the gene names and organism are appended.

    Args:
        raw_result: Dict with ``pdb_name``, ``chain_id``, ``score``,
            ``genes`` and ``organism`` keys.

    Returns:
        Tuple of the display label and the hit's ``score``.
    """
    pdb_name, chain_id, score, genes, organism = (
        raw_result[field]
        for field in ("pdb_name", "chain_id", "score", "genes", "organism")
    )

    label = f"PDB: {pdb_name}.{chain_id}"
    if genes is None:
        return label, score
    return label + f" | Genes: {genes} | Organism: {organism}", score
def switch_viz(new_choice):
    """Produce the visualization outputs for a newly selected dropdown choice.

    Args:
        new_choice: A choice string built with ``choice_sep`` (its second and
            third parts are the PDB id and chain), or ``None`` when nothing
            is selected.

    Returns:
        Tuple ``(html, title_update, description_update)``. With no
        selection, the HTML is empty and both Markdown updates hide their
        component. Otherwise the HTML comes from ``render_html`` and the
        description carries the title returned by ``get_pdb_title``.
    """
    if new_choice is None:
        hidden_title = gr.update(gr.Markdown.get_component_class_id(), visible=False)
        hidden_description = gr.update(
            gr.Markdown.get_component_class_id(), value=None, visible=False
        )
        return "", hidden_title, hidden_description

    # Choice layout: <uniprot id> | <pdb id> | <chain> | <genes>
    pdb_id, chain = new_choice.split(choice_sep)[1:3]

    shown_title = gr.update(gr.Markdown.get_component_class_id(), visible=True)
    pdb_title = get_pdb_title(pdb_id)
    shown_description = gr.update(
        gr.Markdown.get_component_class_id(),
        value=f"""**PDB Title**: {pdb_title}""",
        visible=True,
    )
    rendered = render_html(pdb_id=pdb_id, chain=chain)
    return rendered, shown_title, shown_description
# Gradio UI definition: declares the input widgets, the results widgets and
# the visualization widgets, then wires the two event handlers.
# NOTE(review): this copy of the file has lost its indentation, so the exact
# nesting of the Column/Row layout containers below cannot be determined
# from here -- confirm against the original source before reformatting.
| with gr.Blocks() as demo: | |
| with gr.Column(): | |
| with gr.Column(): | |
| with gr.Row(): | |
| with gr.Column(): | |
# Search inputs: query sequence, requested result count, index choice.
| seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence") | |
| n_results = gr.Number(10, label="N Results") | |
| index_selector = gr.Dropdown( | |
| choices=available_indexes, | |
| value="All Species", | |
| multiselect=False, | |
| visible=True, | |
| label="Index", | |
| ) | |
| search_button = gr.Button("Search", variant="primary") | |
# Label widget that receives the mapping returned by search_and_display.
| search_results = gr.Label( | |
| num_top_classes=max_results, label="Search Results", scale=2 | |
| ) | |
# Visualization widgets start hidden; switch_viz and update_dropdown_menu
# return updates that toggle their visibility.
| viz_header = gr.Markdown("## Visualization", visible=False) | |
| results_selector = gr.Dropdown( | |
| choices=[], | |
| multiselect=False, | |
| visible=False, | |
| label="Visualized Search Result", | |
| ) | |
| viz_body = gr.Markdown("", visible=False) | |
# Initial HTML is rendered with no PDB id and no chain.
| protein_viz = gr.HTML( | |
| value=render_html(pdb_id=None, chain=None), | |
| label="Protein Visualization", | |
| ) | |
# Example sequences that populate the sequence input.
| gr.Examples( | |
| ["APTMPPPLPP", "KFLIYQMECSTMIFGL", "PHFAMPPIHEDHLE", "AEERIISLD"], | |
| inputs=[seq_input], | |
| ) | |
# Clicking "Search" runs search_and_display and fills the results label and
# the result-selection dropdown.
| search_button.click( | |
| search_and_display, | |
| inputs=[seq_input, n_results, index_selector], | |
| outputs=[search_results, results_selector], | |
| ) | |
# Changing the selected result re-renders the visualization via switch_viz.
| results_selector.change( | |
| switch_viz, inputs=results_selector, outputs=[protein_viz, viz_header, viz_body] | |
| ) | |
# Launch the app only when this file is executed as a script.
| if __name__ == "__main__": | |
| demo.launch() | |