# Hugging Face Space app (page captured while the Space status was "Sleeping").
import copy
import html
import json
import os
import re
from datetime import datetime
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select
from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1
from datasets import Dataset
from gradio_molecule3d import Molecule3D
from scipy.special import expit
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForTokenClassification

from model_loader import load_model
# Fine-tuned ProtT5 checkpoint used for prediction. Other checkpoints of the
# same family that can be substituted here:
#   'ThorbenF/prot_t5_xl_uniref50'
#   'ThorbenF/prot_t5_xl_uniref50_cryptic'
#   'ThorbenF/prot_t5_xl_uniref50_database'
checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full'
max_length = 1500  # forwarded to load_model

# Load once at import time; run inference on the GPU when one is available.
model, tokenizer = load_model(checkpoint, max_length)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
model.to(device)
model.eval()  # inference mode (disables dropout-style training behavior)
def normalize_scores(scores):
    """Min-max scale *scores* so the minimum maps to 0 and the maximum to 1.

    Args:
        scores: Sequence or NumPy array of numeric scores.

    Returns:
        A NumPy array of scaled scores. When every value is equal there is no
        range to scale by, so the values are returned unscaled. An empty input
        yields an empty array (``np.min`` would otherwise raise).
    """
    arr = np.asarray(scores)
    if arr.size == 0:
        return arr
    min_score = arr.min()
    max_score = arr.max()
    if max_score > min_score:
        return (arr - min_score) / (max_score - min_score)
    return arr
def read_mol(pdb_path):
    """Return the entire text content of the structure file at *pdb_path*."""
    with open(pdb_path, 'r') as handle:
        content = handle.read()
    return content
def fetch_structure(pdb_id: str, output_dir: str = ".") -> str:
    """
    Obtain the structure file for *pdb_id*, preferring CIF over PDB format.

    Thin wrapper over ``download_structure``: an existing local copy in
    *output_dir* is reused, otherwise the file is downloaded. The path it
    reports is returned as-is (``None`` when nothing could be obtained).
    """
    return download_structure(pdb_id, output_dir)
def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:
    """
    Return a local path to the structure file for *pdb_id*, downloading it
    from RCSB when no local copy exists.

    Both formats are checked locally first (``.cif`` before ``.pdb``) so an
    existing file is reused without any network access. Otherwise the CIF
    file is requested first, then the PDB file.

    Args:
        pdb_id: PDB identifier, used verbatim in the file name and URL.
        output_dir: Directory where files are looked up and saved.

    Returns:
        Path of the existing or downloaded file, or ``None`` if neither
        format was returned with HTTP status 200.

    Raises:
        requests.RequestException: propagated from ``requests.get`` on
            connection errors or timeouts.
    """
    extensions = ('.cif', '.pdb')
    candidates = [os.path.join(output_dir, f"{pdb_id}{ext}") for ext in extensions]

    # Reuse any local copy before touching the network.
    for file_path in candidates:
        if os.path.exists(file_path):
            return file_path

    for ext, file_path in zip(extensions, candidates):
        url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            with open(file_path, 'wb') as f:
                f.write(response.content)
            return file_path
    return None
def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
    """
    Convert an mmCIF file to PDB format using BioPython.

    The output is written to *output_dir* under the input's base name with
    its final extension replaced by ``.pdb``. Only the final extension is
    replaced, so a name like ``a.cif.b.cif`` becomes ``a.cif.b.pdb``. An input
    with a different extension never maps onto itself.

    NOTE(review): the legacy PDB format has fixed-width fields, so very large
    structures or multi-character chain IDs may fail to be written — confirm
    against BioPython's PDBIO behavior if such inputs are expected.

    Returns:
        Path of the written PDB file.
    """
    stem = os.path.splitext(os.path.basename(cif_path))[0]
    pdb_path = os.path.join(output_dir, f"{stem}.pdb")
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure('protein', cif_path)
    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_path)
    return pdb_path
def fetch_pdb(pdb_id):
    """
    Return a local PDB-format file for *pdb_id*.

    The structure is fetched (or reused locally) via ``fetch_structure``. If
    the result is a CIF file, it is converted to PDB format first.

    Raises:
        FileNotFoundError: if the structure could not be obtained in either
            format.
    """
    pdb_path = fetch_structure(pdb_id)
    if pdb_path is None:
        # Fail with a clear message instead of a TypeError from splitext(None).
        raise FileNotFoundError(f"Could not fetch a structure file for PDB ID '{pdb_id}'.")
    if os.path.splitext(pdb_path)[1] == '.cif':
        pdb_path = convert_cif_to_pdb(pdb_path)
    return pdb_path
class _ResidueScoreSelector(Select):
    """PDBIO selector that keeps one chain and a chosen set of residue numbers.

    As a side effect of ``accept_atom``, every written atom whose residue
    number has a score gets its B-factor overwritten with
    ``|1 - score| * 100``.
    """

    def __init__(self, chain_id, selected_residues, scores_dict):
        self.chain_id = chain_id
        # A set keeps accept_residue O(1) per residue instead of a list scan.
        self.selected_residues = set(selected_residues)
        self.scores_dict = scores_dict

    def accept_chain(self, chain):
        return chain.id == self.chain_id

    def accept_residue(self, residue):
        return residue.id[1] in self.selected_residues

    def accept_atom(self, atom):
        resi = atom.parent.id[1]
        if resi in self.scores_dict:
            atom.bfactor = np.absolute(1 - self.scores_dict[resi]) * 100
        return True


def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:
    """
    Write a PDB file with only the selected chain and residues, storing the
    prediction in the B-factor column.

    NOTE(review): the value written is ``|1 - score| * 100`` (an inverted,
    percent-scaled score), not the raw score — confirm this is intended.

    Args:
        input_pdb: Path to a PDB-format file (parsed with ``PDBParser``).
        chain_id: Chain to keep.
        residue_scores: ``(residue_number, score)`` pairs; residues are matched
            by number only (insertion codes are ignored).
        protein_residues: Residues to keep; matched by ``res.id[1]``.

    Returns:
        Path of the written file,
        ``<input stem>_<chain_id>_predictions_scores.pdb``.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', input_pdb)
    output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"

    scores_dict = dict(residue_scores)
    selector = _ResidueScoreSelector(chain_id, (res.id[1] for res in protein_residues), scores_dict)

    io = PDBIO()
    # Only the first model of the structure is written.
    io.set_structure(structure[0])
    io.save(output_pdb, selector)
    return output_pdb
def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
    """Build a PyMOL script that colors residues by score bracket.

    The text opens with a header (PDB ID, chain, timestamp, score type). It
    then has commands that fetch the entry and show chain *segment* as a
    white cartoon. For every non-empty bracket, the script selects its
    residues, shows them as sticks and colors them with that bracket's color.
    """
    palette = {
        "0.0-0.2": "white",
        "0.2-0.4": "lightorange",
        "0.4-0.6": "yelloworange",
        "0.6-0.8": "orange",
        "0.8-1.0": "red"
    }
    chunks = [
        f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n",
        f"\n# PyMOL Visualization Commands\nfetch {pdb_id}, protein\nhide everything, all\n"
        f"show cartoon, chain {segment}\ncolor white, chain {segment}\n",
    ]
    for bracket, members in residues_by_bracket.items():
        if not members:
            continue
        sel_name = "bracket_" + bracket.replace('.', '').replace('-', '_')
        resi_expr = '+'.join(str(member) for member in members)
        chunks.append(
            f"\nselect {sel_name}, resi {resi_expr} and chain {segment}\n"
            f"show sticks, {sel_name}\n"
            f"color {palette[bracket]}, {sel_name}\n"
        )
    return "".join(chunks)
def generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence, scores, current_time, score_type):
    """Render a plain-text report listing residues per score bracket.

    ``protein_residues``, ``sequence`` and ``scores`` are indexed in lockstep:
    entry ``i`` of each describes the same residue. For every bracket, each
    residue whose number (``res.id[1]``) is listed in that bracket is printed
    as ``<resname> <number> <one-letter code> <score to 2 decimals>``.
    """
    pieces = [
        f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n",
        "Residues by Score Brackets:\n\n",
    ]
    for bracket, members in residues_by_bracket.items():
        rows = []
        for idx, residue in enumerate(protein_residues):
            if residue.id[1] in members:
                rows.append(f"{residue.resname} {residue.id[1]} {sequence[idx]} {scores[idx]:.2f}")
        pieces.append(
            f"Bracket {bracket}:\n"
            f"Columns: Residue Name, Residue Number, One-letter Code, {score_type} Score\n"
            + "\n".join(rows)
            + "\n\n"
        )
    return "".join(pieces)
def _bin_residues_by_score(residue_scores):
    """Group residue numbers into the five 0.2-wide score brackets.

    Brackets are half-open ``[lower, upper)`` except the last, which also
    includes 1.0. Without that, the top-scoring residue after min-max scaling
    (score exactly 1.0) would land in no bracket. Scores outside [0, 1] are
    not assigned.

    Returns:
        Dict mapping every bracket label (e.g. ``"0.0-0.2"``) to a list of
        residue numbers, in input order.
    """
    brackets = {
        "0.0-0.2": (0.0, 0.2),
        "0.2-0.4": (0.2, 0.4),
        "0.4-0.6": (0.4, 0.6),
        "0.6-0.8": (0.6, 0.8),
        "0.8-1.0": (0.8, 1.0)
    }
    top_label = "0.8-1.0"
    binned = {label: [] for label in brackets}
    for resi, score in residue_scores:
        for label, (lower, upper) in brackets.items():
            if lower <= score < upper or (label == top_label and score == upper):
                binned[label].append(resi)
                break
    return binned


def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
    """Predict per-residue binding-site scores for one chain and build outputs.

    Args:
        pdb_id_or_file: Path to a structure file (a ``.pdb`` path or any
            existing file; ``.cif`` files are converted to PDB first), or a
            PDB ID to fetch.
        segment: Chain identifier to predict on.
        score_type: ``'normalized'`` reports min-max scaled scores; any other
            value reports the raw sigmoid scores.

    Returns:
        ``(pymol_commands, mol_vis_html, [prediction_txt, scored_pdb],
        raw_residue_scores, norm_residue_scores, pdb_id, segment)``. The score
        lists are ``(residue_number, score)`` pairs.

    Raises:
        KeyError: if *segment* is not a chain of the first model.
        ValueError: if the chain contains no amino-acid residues.
    """
    # Resolve the input to a local structure file.
    if pdb_id_or_file.endswith('.pdb') or os.path.isfile(pdb_id_or_file):
        pdb_path = pdb_id_or_file
        if os.path.splitext(pdb_path)[1].lower() == '.cif':
            # The B-factor export and the 3Dmol "pdb" model need PDB format.
            pdb_path = convert_cif_to_pdb(pdb_path)
        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]
    else:
        pdb_id = pdb_id_or_file
        pdb_path = fetch_pdb(pdb_id)

    _, ext = os.path.splitext(pdb_path)
    parser = MMCIFParser(QUIET=True) if ext.lower() == '.cif' else PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_path)

    chain = structure[0][segment]
    protein_residues = [res for res in chain if is_aa(res)]
    if not protein_residues:
        raise ValueError(f"Chain '{segment}' contains no amino-acid residues.")
    sequence = "".join(seq1(res.resname) for res in protein_residues)
    sequence_id = [res.id[1] for res in protein_residues]

    # The tokenizer expects space-separated residues.
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        # Take batch item 0 explicitly; squeeze() would also drop the token
        # axis for a 1-token input.
        logits = model(input_ids).logits[0].detach().cpu().numpy()
    # Keep only per-residue positions, so special tokens the tokenizer adds
    # do not skew min-max normalization. NOTE(review): assumes such tokens
    # follow the residues (as with a T5 end-of-sequence token), the same
    # alignment the residue/score pairing relies on.
    logits = logits[:len(sequence)]

    # sigmoid(l1 - l0) is the two-class softmax probability of class 1.
    raw_scores = expit(logits[:, 1] - logits[:, 0])
    normalized_scores = normalize_scores(raw_scores)
    display_scores = normalized_scores if score_type == 'normalized' else raw_scores

    residue_scores = list(zip(sequence_id, display_scores))
    # Both variants are returned so the UI can switch without re-running the model.
    raw_residue_scores = list(zip(sequence_id, raw_scores))
    norm_residue_scores = list(zip(sequence_id, normalized_scores))

    residues_by_bracket = _bin_residues_by_score(residue_scores)

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    display_score_type = "Normalized" if score_type == 'normalized' else "Raw"
    result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
                                       display_scores, current_time, display_score_type)
    pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)

    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
    mol_vis = molecule(pdb_path, residue_scores, segment)

    prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
    with open(prediction_file, "w") as f:
        f.write(result_str)
    scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
    # os.replace overwrites an existing target on every platform (os.rename
    # raises on Windows when the target exists).
    os.replace(scored_pdb, scored_pdb_name)

    return pymol_commands, mol_vis, [prediction_file, scored_pdb_name], raw_residue_scores, norm_residue_scores, pdb_id, segment
def molecule(input_pdb, residue_scores=None, segment='A'):
    """Render a structure as an embeddable 3Dmol.js viewer.

    Args:
        input_pdb: Path to a PDB-format file; its text is embedded in the page.
        residue_scores: Optional iterable of ``(residue_number, score)`` pairs.
            When given, chain *segment* is drawn as a white cartoon. Scored
            residues are drawn as sticks, colored by 0.2-wide score class
            (0xFFFFFF, 0xFFD580, 0xFFA500, 0xFF4500, 0xFF0000). The classes are
            half-open ``[lower, upper)``, and the top class also includes 1.0.
            When ``None``, no model is added to the viewer.
        segment: Chain identifier to display.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` holds the viewer page.
    """
    def to_js(value):
        # JSON is valid JavaScript. Escaping "</" stops embedded text (e.g. a
        # crafted REMARK line in an uploaded file) from closing <script> early.
        return json.dumps(value).replace("</", "<\\/")

    mol = read_mol(input_pdb)

    high_score_script = ""
    if residue_scores is not None:
        residue_scores = list(residue_scores)
        # (lower, upper, stick color, opacity) for each score class.
        score_classes = [
            (0.0, 0.2, "0xFFFFFF", 0.5),
            (0.2, 0.4, "0xFFD580", 0.7),
            (0.4, 0.6, "0xFFA500", 1),
            (0.6, 0.8, "0xFF4500", 1),
            (0.8, 1.0, "0xFF0000", 1),
        ]
        chain_selection = to_js({"chain": str(segment)})
        script_lines = [
            "// Load the original model and apply white cartoon style",
            'let chainModel = viewer.addModel(pdb, "pdb");',
            "chainModel.setStyle({}, {});",
            "chainModel.setStyle(" + chain_selection + ', {"cartoon": {"color": "white"}});',
        ]
        for number, (lower, upper, color, opacity) in enumerate(score_classes, start=1):
            is_last = number == len(score_classes)
            # Same intervals as the text report / PyMOL brackets.
            members = [
                int(resi) for resi, score in residue_scores
                if lower <= score < upper or (is_last and score == upper)
            ]
            selection = to_js({"chain": str(segment), "resi": members})
            style = to_js({"stick": {"color": color, "opacity": opacity}})
            model_var = f"class{number}Model"
            script_lines += [
                f"// Score class {number}: {color} sticks",
                f'let {model_var} = viewer.addModel(pdb, "pdb");',
                f"{model_var}.setStyle({{}}, {{}});",
                f"{model_var}.setStyle({selection}, {style});",
            ]
        high_score_script = "\n".join(script_lines)

    html_content = f"""
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
.mol-container {{
    width: 100%;
    height: 700px;
    position: relative;
}}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = {to_js(mol)};  // JSON-encoded string literal of the PDB text
$(document).ready(function () {{
    let element = $("#container");
    let config = {{ backgroundColor: "white" }};
    let viewer = $3Dmol.createViewer(element, config);
    {high_score_script}
    // Add hover functionality
    viewer.setHoverable(
        {{}},
        true,
        function(atom, viewer, event, container) {{
            if (!atom.label) {{
                atom.label = viewer.addLabel(
                    atom.resn + ":" +atom.resi + ":" + atom.atom,
                    {{
                        position: atom,
                        backgroundColor: 'mintcream',
                        fontColor: 'black',
                        fontSize: 18,
                        padding: 4
                    }}
                );
            }}
        }},
        function(atom, viewer) {{
            if (atom.label) {{
                viewer.removeLabel(atom.label);
                delete atom.label;
            }}
        }}
    );
    viewer.zoomTo();
    viewer.render();
    viewer.zoom(0.8, 2000);
}});
</script>
</body>
</html>
"""
    # html.escape covers &, <, >, " and ' so the page survives as an attribute value.
    return f'<iframe width="100%" height="700" srcdoc="{html.escape(html_content, quote=True)}"></iframe>'
with gr.Blocks(css="""
/* Customize Gradio button colors */
#visualize-btn, #predict-btn {
    background-color: #FF7300; /* Deep orange */
    color: white;
    border-radius: 5px;
    padding: 10px;
    font-weight: bold;
}
#visualize-btn:hover, #predict-btn:hover {
    background-color: #CC5C00; /* Darkened orange on hover */
}
""") as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    # Mode selection
    mode = gr.Radio(
        choices=["PDB ID", "Upload File"],
        value="PDB ID",
        label="Input Mode",
        info="Choose whether to input a PDB ID or upload a PDB/CIF file."
    )

    # Input components based on mode
    pdb_input = gr.Textbox(value="2F6V", label="PDB ID", placeholder="Enter PDB ID here...")
    pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
    visualize_btn = gr.Button("Visualize Structure", elem_id="visualize-btn")

    molecule_output2 = Molecule3D(label="Protein Structure", reps=[
        {
            "model": 0,
            "style": "cartoon",
            "color": "whiteCarbon",
            "residue_range": "",
            "around": 0,
            "byres": False,
        }
    ])

    with gr.Row():
        segment_input = gr.Textbox(value="A", label="Chain ID (protein)", placeholder="Enter Chain ID here...",
                                   info="Choose in which chain to predict binding sites.")
        prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")

    # Add score type selector
    score_type = gr.Radio(
        choices=["Normalized Scores", "Raw Scores"],
        value="Normalized Scores",
        label="Score Visualization Type",
        info="Choose which score type to visualize"
    )

    molecule_output = gr.HTML(label="Protein Structure")
    explanation_vis = gr.Markdown("""
Score dependent colorcoding:
- 0.0-0.2: white
- 0.2–0.4: light orange
- 0.4–0.6: yellow orange
- 0.6–0.8: orange
- 0.8–1.0: red
""")
    predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
    gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
    download_output = gr.File(label="Download Files", file_count="multiple")

    # Results of the last prediction, kept so the score type can be switched
    # without re-running the model.
    raw_scores_state = gr.State(None)
    norm_scores_state = gr.State(None)
    last_pdb_path = gr.State(None)
    last_segment = gr.State(None)
    last_pdb_id = gr.State(None)

    def _resolve_input_path(mode, pdb_id, pdb_file):
        """Return a local PDB-format path for the current input mode.

        "PDB ID" mode fetches the entry (converting CIF if needed). "Upload
        File" mode uses the uploaded file and converts ``.cif`` uploads to PDB.

        Raises:
            gr.Error: if no file was uploaded or the mode is unknown.
        """
        if mode == "PDB ID":
            return fetch_pdb(pdb_id)
        if mode == "Upload File":
            if pdb_file is None:
                raise gr.Error("Please upload a PDB or CIF file first.")
            # Depending on the Gradio version the upload arrives as an object
            # with a ``.name`` path or as a plain path string.
            file_path = getattr(pdb_file, "name", pdb_file)
            if os.path.splitext(file_path)[1].lower() == '.cif':
                return convert_cif_to_pdb(file_path)
            return file_path
        raise gr.Error(f"Unknown input mode: {mode}")

    def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
        """Run a prediction for the current UI inputs.

        Returns values for ``[predictions_output, molecule_output,
        download_output, raw_scores_state, norm_scores_state, last_pdb_path,
        last_segment, last_pdb_id]``.
        """
        selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
        pdb_path = _resolve_input_path(mode, pdb_id, pdb_file)
        pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, _segment = process_pdb(
            pdb_path, chain_id, selected_score_type)
        # Store the resolved file path (not just the ID) so later score-type
        # switches can re-read the structure.
        return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result

    def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
        """Re-render outputs for a new score type from the stored prediction.

        Returns values for ``[molecule_output, predictions_output,
        download_output]``; all ``None`` until a prediction has been run.
        """
        if any(value is None for value in (raw_scores, norm_scores, pdb_path, segment, pdb_id)):
            return None, None, None

        selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
        selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores

        mol_vis = molecule(pdb_path, selected_scores, segment)

        # Re-read the chain for residue names and one-letter codes.
        _, ext = os.path.splitext(pdb_path)
        parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
        structure = parser.get_structure('protein', pdb_path)
        chain = structure[0][segment]
        protein_residues = [res for res in chain if is_aa(res)]
        sequence = "".join(seq1(res.resname) for res in protein_residues)

        score_brackets = {
            "0.0-0.2": (0.0, 0.2),
            "0.2-0.4": (0.2, 0.4),
            "0.4-0.6": (0.4, 0.6),
            "0.6-0.8": (0.6, 0.8),
            "0.8-1.0": (0.8, 1.0)
        }
        top_bracket = "0.8-1.0"
        residues_by_bracket = {bracket: [] for bracket in score_brackets}
        for resi, score in selected_scores:
            for bracket, (lower, upper) in score_brackets.items():
                # Half-open intervals, but the top bracket also takes 1.0 (the
                # maximum normalized score), which would otherwise be dropped.
                if lower <= score < upper or (bracket == top_bracket and score == upper):
                    residues_by_bracket[bracket].append(resi)
                    break

        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        display_score_type = "Normalized" if selected_score_type == 'normalized' else "Raw"
        scores_array = [score for _, score in selected_scores]
        result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
                                           scores_array, current_time, display_score_type)
        pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)

        scored_pdb = create_chain_specific_pdb(pdb_path, segment, selected_scores, protein_residues)

        prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
        with open(prediction_file, "w") as f:
            f.write(result_str)
        scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
        # os.replace overwrites an existing target on every platform.
        os.replace(scored_pdb, scored_pdb_name)

        return mol_vis, pymol_commands, [prediction_file, scored_pdb_name]

    def fetch_interface(mode, pdb_id, pdb_file):
        """Return the structure path to show in the Molecule3D viewer."""
        return _resolve_input_path(mode, pdb_id, pdb_file)

    def toggle_mode(selected_mode):
        """Show either the PDB-ID textbox or the file-upload widget."""
        show_id_input = selected_mode == "PDB ID"
        return gr.update(visible=show_id_input), gr.update(visible=not show_id_input)

    mode.change(
        toggle_mode,
        inputs=[mode],
        outputs=[pdb_input, pdb_file]
    )

    prediction_btn.click(
        process_interface,
        inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
        outputs=[predictions_output, molecule_output, download_output,
                 raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id]
    )

    # Update visualization, PyMOL commands, and files when score type changes
    score_type.change(
        update_visualization_and_files,
        inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id],
        outputs=[molecule_output, predictions_output, download_output]
    )

    visualize_btn.click(
        fetch_interface,
        inputs=[mode, pdb_input, pdb_file],
        outputs=molecule_output2
    )

    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["7RPZ", "A"],
            ["2IWI", "B"],
            ["7LCJ", "R"],
            ["4OBE", "A"]
        ],
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

demo.launch(share=True)