import gradio as gr
from transformers import AutoModel, AutoConfig, logging
from huggingface_hub import login
import os
import json
import torch
from pathlib import Path
import seaborn as sns
from collections import defaultdict
from gradio_modal import Modal

logging.set_verbosity_error()

metl_config = AutoConfig.from_pretrained('gitter-lab/METL', trust_remote_code=True, cache_dir='./cache')
metl = AutoModel.from_pretrained('gitter-lab/METL', trust_remote_code=True, cache_dir='./cache')

# Path of the PDB file currently used for display and prediction, or None.
# NOTE(review): this is one module-level value shared by every session of the
# app, so concurrent users can overwrite each other's PDB file; a per-session
# gr.State would avoid that.
pdb_path = None

# Maximum number of variant lines predicted per "Run Prediction" click
# (the help text advertises "up to 100 different METL predictions").
MAX_PREDICTIONS = 100

# Shown in place of the molecule when the variant text cannot be parsed.
INVALID_VARIANTS_MESSAGE = 'The variants given were not in a valid JSON list format'

RADIO_CSS = """
#indexing>div {
    flex-direction: column;
    width: 50px;
}

#pdbUpload {
    height: 150px;
}

#modelPDBRow{
    height: 150px;
}

#modelInputCol{
    gap:0px;
}

#modelStatus {
    margin: auto;
    text-align: center;
}

.options {
    height: 300px !important;
}

.multiModalText > label > div > textarea{
    border-style: solid;
    border-radius: var(--block-radius);
    padding-left: 5px;
    border-color: var(--border-color-primary);
    border-width: 1px;
}

.multiModalText > label > div > button.upload-button {
    margin-right: 5px;
}

#wildTypeSequence > label > div > button.upload-button {
    display: none;
}

.multiModalText > label > div > button.submit-button {
    display: none;
}

.main {
    width: 50vw;
    min-width: 700px;
    margin-left: auto;
    margin-right: auto;
}

#moleculeRow > div:not(.form) {
    flex-grow: 90;
    overflow: hidden;
}

#variantCheck > div.wrap > label {
    width: 100%;
}

div:has(> button.selectorButton){
    flex-direction: row;
    flex-wrap: nowrap;
}

.selectorButton{
    width: 100px;
    min-width: 100px;
}

.selectionHint{
    text-align: center;
}

#helpModal{
    position: absolute;
    bottom: calc(100% - 2.58rem);
    width: 5rem;
    left: 101%;
}

.modal-container{
    width: 45rem;
}

#helpModalText{
    font-size:large;
}

p {
    font-size: large;
}

li {
    font-size: large;
}
"""

RED = "#DA667B"
GREEN = "#6B9A5F"


def _residue_number(mutation):
    """Return the residue number of a mutation string such as ``"T17P"``.

    The residue is everything between the first and last characters. It is
    normalised through ``int`` so that only a plain integer is ever written
    into the generated viewer JavaScript (the variant text comes straight
    from the user). Returns None when that part is not an integer.
    """
    try:
        return str(int(mutation[1:-1]))
    except ValueError:
        return None


def _residue_style_code(variants):
    """Return 3Dmol.js ``setStyle`` calls that colour each variant's residues.

    Args:
        variants: list of variant strings, each a comma-separated list of
            mutations such as ``"T17P,T54F"``.

    Residues mutated by more than one variant (or twice in one variant) are
    coloured ``#570606``; every other residue gets its variant's colour.
    """
    if len(variants) <= 9:
        cmap = sns.color_palette('colorblind').as_hex()
        del cmap[-3]  # Doesn't show up well on the molecule
    else:
        # No colorblind-safe palette past 9 items.
        cmap = sns.color_palette('husl', len(variants)).as_hex()

    # First pass: count how often each residue is mutated, so that shared
    # residues can be drawn in the duplicate colour on the second pass.
    duplicate_counts = defaultdict(int)
    for variant in variants:
        for mutation in variant.split(','):
            residue = _residue_number(mutation)
            if residue is not None:
                duplicate_counts[residue] += 1

    code = ""
    visited_counts = defaultdict(int)
    for index, variant in enumerate(variants):
        for mutation in variant.split(','):
            residue = _residue_number(mutation)
            if residue is None:
                # Not a plain integer residue: never splice it into the script.
                continue
            visited_counts[residue] += 1
            if duplicate_counts[residue] > 1:
                if visited_counts[residue] == 1:
                    # A later occurrence emits the duplicate-colour style.
                    continue
                color = "#570606"
            else:
                color = cmap[index]
            code += ('viewer.getModel(0).setStyle({resi:[' + residue + ']}, '
                     '{cartoon:{color:"' + color + '"}});\n')
    return code


def generate_iframe_html(variants):
    """Build the HTML for the molecule viewer, colouring mutated residues.

    Args:
        variants: None, a list of variant strings (each a comma-separated
            list of mutations such as ``"T17P,T54F"``), or a JSON string
            encoding such a list.

    Returns:
        The viewer HTML, or an error message when ``variants`` is a string
        that is not valid JSON. With no PDB file loaded, the script part is
        left empty.

    NOTE(review): the HTML/JS template literals in this function are empty in
    this copy of the file, so ``mol`` and ``residue_code`` are computed but
    not spliced into any visible markup; the 3Dmol.js template appears to be
    missing -- confirm against the deployed version.
    """
    if pdb_path is not None:
        with open(pdb_path, 'r', encoding='utf-8') as f:
            mol = f.read()

        residue_code = ""
        if variants is not None:
            if isinstance(variants, str):
                try:
                    variants = json.loads(variants)
                except ValueError:
                    return INVALID_VARIANTS_MESSAGE
            residue_code = _residue_style_code(variants)

        script = """ """
    else:
        script = ""

    x = """
""" + script + """ """
    return x


def get_iframe(variants):
    """Return the molecule display HTML for ``variants``.

    NOTE(review): the f-string below is empty in this copy of the file, so
    ``x`` is not embedded anywhere; the iframe wrapper markup appears to be
    missing -- confirm against the deployed version.
    """
    x = generate_iframe_html(variants)
    return f""""""


def to_zero_based(variants):
    """Convert lines of 1-based JSON variant lists to 0-based variant lists.

    Args:
        variants: lines of text, each a JSON list of variant strings whose
            comma-separated mutations look like ``"T17P"``.

    Returns:
        One list per input line, with every residue number decreased by one,
        e.g. ``['["T17P,T54F"]']`` -> ``[["T16P,T53F"]]``.

    Raises:
        ValueError: a line is not valid JSON or a residue is not an integer.
    """
    zero_based = []
    for line in variants:
        new_variants = []
        for variant in json.loads(line):
            shifted = [
                f"{mutation[0]}{int(mutation[1:-1]) - 1}{mutation[-1]}"
                for mutation in variant.split(',')
            ]
            new_variants.append(",".join(shifted))
        zero_based.append(new_variants)
    return zero_based


def get_lines_from_multimodal(modal_output):
    """Return the lines of input held by a MultimodalTextbox value.

    Args:
        modal_output: dict with ``"text"`` (str) and ``"files"`` (list of file
            paths).

    Returns:
        ``[]`` when both are empty; ``[text]`` when no file is attached;
        otherwise the stripped, non-blank lines of the first file, followed by
        the typed text if there is any.
    """
    text = modal_output['text']
    files = modal_output['files']
    if len(text) == 0 and len(files) == 0:
        return []
    if len(files) == 0:
        return [text]
    with open(files[0], 'r') as f:
        lines = [line.strip() for line in f]
    # Blank lines would otherwise be handed to json.loads and fail.
    lines = [line for line in lines if line]
    if len(text) > 0:
        lines.append(text)
    return lines


def get_color(color, model_text):
    """Return the status HTML for ``model_text``.

    NOTE(review): ``color`` is not used by the visible f-string; the colour
    markup around ``model_text`` appears to be missing from this copy of the
    file -- confirm against the deployed version.
    """
    return f'{model_text}'


def get_file(filepath: str):
    """Remember an uploaded PDB file and show it without highlighted variants."""
    global pdb_path
    pdb_path = filepath
    print(filepath)
    iframe = get_iframe(None)
    return iframe


def empty_pdb_path(button):
    """Forget the current PDB file and blank the molecule display.

    ``button`` is the triggering component's value and is ignored.
    """
    global pdb_path
    pdb_path = None
    return ""


def load_model(model_id, _):
    """Load the METL model named ``model_id`` into the global ``metl``.

    ``model_id`` is looked up (lower-cased) in ``IDENT_UUID_MAP`` first and
    then, as given, in ``UUID_URL_MAP``.

    Returns:
        (status HTML, run-button update); the run button is enabled only
        when a model was loaded.
    """
    if not isinstance(model_id, str):
        return get_color(RED, "Select Model"), gr.Button(interactive=False)

    if model_id.lower() in metl.config.IDENT_UUID_MAP:
        metl.load_from_ident(model_id)
    elif model_id in metl.config.UUID_URL_MAP:
        metl.load_from_uuid(model_id)
    else:
        return get_color(RED, "Model Load Failed"), gr.Button(interactive=False)

    return get_color(GREEN, f"{model_id} loaded"), gr.Button(interactive=True)


def update_pdb(variant_modal, indexing):
    """Refresh the variant checkboxes and molecule display from the variant box.

    Only the first line of input is displayed. When ``indexing`` is 1 the
    residue numbers are shifted to 0-based first.

    Returns:
        (checkbox update, molecule HTML). Both outputs are always returned:
        with no input they are left unchanged, and unparsable variants show
        an error message in the molecule display.
    """
    variants = get_lines_from_multimodal(modal_output=variant_modal)
    if not variants:
        return gr.CheckboxGroup(), gr.HTML()

    try:
        if indexing == 1:
            # Only the first line is shown, so only that line is converted.
            variants = to_zero_based(variants[:1])[0]
        else:
            variants = json.loads(variants[0])
    except (ValueError, TypeError, AttributeError, IndexError):
        return gr.CheckboxGroup(), INVALID_VARIANTS_MESSAGE

    print(variants)
    return gr.CheckboxGroup(choices=variants, value=variants, visible=True), get_iframe(variants=variants)


def select_or_deselect_all(button_name, choices, variant_modal, indexing):
    """Handle the "Select all" / "Deselect all" buttons.

    Args:
        button_name: label of the clicked button; labels containing "De"
            deselect everything.
        choices: current checkbox value (unused; kept for the event wiring).
        variant_modal: MultimodalTextbox value holding the variants.
        indexing: 0 or 1, the index base of the variants.

    Returns:
        (checkbox update, molecule HTML).
    """
    if "De" in button_name:
        # Leave the existing choices untouched: the checkbox *value* passed in
        # holds only the selected items, so reusing it as choices would drop
        # every unchecked variant from the list.
        return gr.CheckboxGroup(visible=True, value=[]), get_iframe(variants=None)
    return update_pdb(variant_modal=variant_modal, indexing=indexing)


def hide_variants(checkbox_values):
    """Redraw the molecule highlighting only the checked variants."""
    return get_iframe(variants=checkbox_values)


def populate_example():
    """Fill every input with the GB1 example and load its model and PDB."""
    global pdb_path
    model = "metl-l-2m-3d-gb1"
    wt = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"
    variants = '["T17P,T54F", "V28L,F51A"]'  # "T17P,V28L,F51A,T54F"
    pdb_path = './2qmt_p.pdb'

    status, pred_button = load_model(model, None)

    wt_dict = {"text": wt, "files": []}
    variants_dict = {"text": variants, "files": []}
    default_checkbox, iframe = update_pdb(variant_modal=variants_dict, indexing=0)

    return (
        model,
        status,
        pred_button,
        wt_dict,
        variants_dict,
        iframe,
        pdb_path,
        default_checkbox,
        gr.Button(visible=True),
        gr.Button(visible=True),
    )


def hide_mutation():
    """Unused placeholder."""
    pass


def predict(input_multi_modal, variant_multi_modal, variant_index_type):
    """Run METL on the wild type sequence for each line of variants.

    Args:
        input_multi_modal: MultimodalTextbox value; its first line is the
            wild type sequence.
        variant_multi_modal: MultimodalTextbox value; one JSON list of
            variants per line (at most MAX_PREDICTIONS lines are used).
        variant_index_type: 0 or 1, the index base of the variants.

    Returns:
        (output text, molecule HTML, checkbox update). The output is the
        logits of a single prediction, or a JSON list of
        ``{"wt", "variants", "logits"}`` records for several lines, or an
        error message.
    """
    input_sequences = get_lines_from_multimodal(input_multi_modal)
    variants = get_lines_from_multimodal(variant_multi_modal)

    if len(input_sequences) == 0 or len(variants) == 0:
        errors = []
        if len(input_sequences) == 0:
            errors.append("Input sequences were not given, but a wild type must be given.")
        if len(variants) == 0:
            errors.append("Mutations were not given, but mutations must be given in JSON array format to predict with METL")
        err_out = "Invalid input. " + " ".join(errors)
        # The third output is the variant CheckboxGroup: clear it to match
        # the variant-free molecule display.
        return err_out, get_iframe(None), gr.CheckboxGroup(choices=[], value=[])

    try:
        if variant_index_type == 1:
            variants = to_zero_based(variants)
        else:
            variants = [json.loads(variant) for variant in variants]
    except (ValueError, TypeError, AttributeError, IndexError):
        err_out = "One or more of the mutations given were not in a valid JSON list format"
        return err_out, get_iframe(None), gr.CheckboxGroup(choices=[], value=[])

    metl.eval()
    outputs = []
    sequence = input_sequences[0]
    for variant in variants[:MAX_PREDICTIONS]:
        encoded_variants = metl.encoder.encode_variants(sequence, variant)

        with torch.no_grad():
            if pdb_path is not None:
                predictions = metl(torch.tensor(encoded_variants), pdb_fn=pdb_path)
            else:
                predictions = metl(torch.tensor(encoded_variants))

        outputs.append({
            "wt": sequence,
            "variants": variant,
            "logits": predictions.tolist(),
        })

    out_str = json.dumps(outputs) if len(outputs) > 1 else str(outputs[0]['logits'])

    variants_dict = {"text": json.dumps(variants[0]), "files": []}

    # Indexing 0 is passed even when the user chose 1: the variants were
    # already shifted above, and update_pdb must not subtract 1 a second time.
    checkbox, iframe = update_pdb(variant_modal=variants_dict, indexing=0)

    return out_str, iframe, checkbox


# NOTE(review): confirm the layout nesting below matches the intended page design.
with gr.Blocks(css=RADIO_CSS) as demo:
    with gr.Row(equal_height=True, elem_id="modelPDBRow"):
        with gr.Column(elem_id="modelInputCol"):
            metl_model_id = gr.Dropdown(label="METL model IDENT or UUID", choices=list(metl.config.IDENT_UUID_MAP.keys()), allow_custom_value=False)
            metl_model_status = gr.HTML(get_color(RED, "Select Model"), elem_id="modelStatus")
        with gr.Column():
            upload_pdb = gr.File(label="PDB File upload", elem_id="pdbUpload", file_types=[".pdb", ".txt"])

    with gr.Column():
        metl_seq_input = gr.MultimodalTextbox(label="Input Protein Sequence", interactive=True, elem_classes="multiModalText", elem_id="wildTypeSequence")
        with gr.Row():
            metl_variants = gr.MultimodalTextbox(label="JSON variant list", scale=100, interactive=True, elem_classes="multiModalText", file_types=[".json", ".txt"])
            # elem_id must be a plain string for the "#indexing" CSS rule.
            variant_indexing = gr.Radio(choices=[0, 1], elem_id="indexing", min_width=80, label='Indexing', value=0)
            metl_update_pdb_display = gr.Button("Update PDB display", min_width=100)

    metl_output = gr.TextArea(label="Output from METL", interactive=False, show_copy_button=True)
    help_text = gr.Markdown("Load a model and if necessary, upload a pdb file to get started. File inputs for mutations must be 1 JSON list per line. When a file is uploaded with multiple mutations, the first mutation will be displayed on the molecule.")
    metl_run_button = gr.Button("Run Prediction", interactive=False)
    metl_load_example = gr.Button("Load Example")

    with gr.Row(elem_id="moleculeRow"):
        molecule = gr.HTML()
        with gr.Column(scale=10):
            with gr.Row():
                # The variable names are the reverse of the button labels;
                # select_or_deselect_all dispatches on the label text.
                select_all = gr.Button(elem_classes=["selectorButton"], value="Deselect all", visible=False)
                deselect_all = gr.Button(elem_classes=["selectorButton"], value="Select all", visible=False)
            show_hide_variants = gr.CheckboxGroup(show_label=False, visible=False, elem_id="variantCheck")

    with Modal(visible=False) as help_modal:
        modal_text = """
This is a demo for [METL](https://huggingface.co/gitter-lab/METL).
The supported METL models are listed in the dropdown. The specifics of each of these models may be found in the above 🤗 link or [here](https://github.com/gitter-lab/metl?tab=readme-ov-file).

To run this demo, follow these steps:
1. select a model in the provided dropdown.
2. upload a pdb file if it is required for your prediction.
3. paste in your wild type sequence
4. paste in your mutations in JSON list format, where each mutation in the list is a CSV string separated by double quotes (").
    - an example is provided when the "Load Example" button is pressed.
    - if a PDB is given, and mutations are in the corresponding text box then update button may be used to display those mutations in an interactive molecule display using 3Dmol.js.
5. Press "Run Prediction"

For cases where many combinations of mutations are given, uploading a text file where each line is a new JSON list (as described above) will allow up to 100 different METL predictions!
"""
        gr.Markdown(modal_text, elem_id="helpModalText")
    help_alert = gr.Button("Help!", elem_id="helpModal")
    help_alert.click(lambda: Modal(visible=True), None, help_modal)

    ## Model and PDB event handlers
    metl_model_id.input(fn=load_model, inputs=[metl_model_id, metl_model_id], outputs=[metl_model_status, metl_run_button])
    # Clearing the upload must also reset the global pdb_path; otherwise later
    # predictions and displays keep using the removed file.
    upload_pdb.clear(fn=empty_pdb_path, inputs=upload_pdb, outputs=molecule, show_progress=False)
    upload_pdb.upload(fn=get_file, inputs=upload_pdb, outputs=molecule, show_progress=False)
    metl_update_pdb_display.click(fn=update_pdb, inputs=[metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)

    ## Predicting event handlers
    metl_run_button.click(fn=predict, inputs=[metl_seq_input, metl_variants, variant_indexing], outputs=[metl_output, molecule, show_hide_variants], show_progress=False)

    ## Load example
    metl_load_example.click(fn=populate_example, outputs=[metl_model_id, metl_model_status, metl_run_button, metl_seq_input, metl_variants, molecule, upload_pdb, show_hide_variants, select_all, deselect_all], show_progress=False)

    ## Event handlers for the molecule display
    select_all.click(fn=select_or_deselect_all, inputs=[select_all, show_hide_variants, metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)
    deselect_all.click(fn=select_or_deselect_all, inputs=[deselect_all, show_hide_variants, metl_variants, variant_indexing], outputs=[show_hide_variants, molecule], show_progress=False)
    show_hide_variants.input(fn=hide_variants, inputs=show_hide_variants, outputs=molecule, show_progress=False)

demo.launch()

# test model: METL-L-2M-3D-GB1
# test wild type: MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE
# test variants: ["T17P,T54F", "V28L,F51A", "T17P,V28L,F51A,T54F"]