Spaces:
Sleeping
Sleeping
import asyncio | |
import queue | |
import threading | |
import spaces | |
import io | |
import os | |
import tempfile | |
import uuid | |
import zipfile | |
from datetime import datetime | |
from pathlib import Path | |
from time import time | |
from dotenv import load_dotenv | |
import torch | |
from email_validator import validate_email, EmailNotValidError | |
from Bio import SeqIO | |
import gradio as gr | |
from gradio_rangeslider import RangeSlider | |
from omegaconf import OmegaConf | |
import pandas as pd | |
from rdkit import Chem | |
from inference import (read_fragment_library, process_fragment_library, extract_pockets, | |
dock_fragments, generate_linkers, select_fragment_pairs) | |
from app import static, fn, db | |
load_dotenv() | |
RESULTS_DIR = os.getenv('RESULTS', 'results') | |
MAX_CONCURRENT_JOBS = 1 | |
task_queue = queue.Queue() | |
semaphore = threading.Semaphore(MAX_CONCURRENT_JOBS) | |
lock = threading.Lock() | |
job_db = db.init_job_db() | |
Path(tempfile.gettempdir(), 'gradio').mkdir(exist_ok=True) | |
gr.set_static_paths(paths=["data/", RESULTS_DIR, "app/"]) | |
os.chmod('./fpocket', 0o755) | |
def task_worker(): | |
"""Worker function to process tasks from the queue with concurrency limit.""" | |
while True: | |
data = task_queue.get() # Get the next task from the queue | |
with semaphore: # Ensure only 'MAX_CONCURRENT_JOBS' tasks run at once | |
with lock: # Ensure only one task is processed at a time (for shared state) | |
dock_link(*data) | |
task_queue.task_done() | |
worker_threads = [] | |
for _ in range(MAX_CONCURRENT_JOBS): | |
worker_thread = threading.Thread(target=task_worker, daemon=True) | |
worker_threads.append(worker_thread) | |
worker_thread.start() | |
FRAG_LIBS = {'': None} | { | |
lib_path.stem.replace('_', ' '): str(lib_path) for lib_path in Path('data/fragment_libraries').glob('*') | |
} | |
FRAG_LIB_PROCESS_OPTS = { | |
'Dehalogenate Fragments': 'dehalogenate', | |
'Discard Inorganic Fragments': 'discard_inorganic' | |
} | |
POCKET_EXTRACT_OPTS = { | |
'Topological Prediction with Fpocket': { | |
'name': 'fpocket', | |
'info': 'If your protein structure contains co-crystallized ligands, you may CLICK ON ' | |
'the ligand with your desired binding pose to predict its corresponding pocket. ' | |
'Otherwise, pockets will be predicted based on the protein structure alone. After extracting ' | |
'the pocket(s), CLICK ON your desired pocket to SELECT ONE for fragment linking.', | |
'params': {} | |
}, | |
'Fragment Conformer Clustering': { | |
'name': 'clustering', | |
'info': 'Conformers of docked fragments will be clustered based on their spatial similarity, and conformers ' | |
'within a cluster will be selected for linking. This strategy takes delayed effect AFTER DOCKING.' | |
} | |
} | |
def gr_error_wrapper(func): | |
def wrapper(*args, **kwargs): | |
try: | |
return func(*args, **kwargs) | |
except Exception as e: | |
raise gr.Error(str(e)) | |
return wrapper | |
async def query_job_status(job_id): | |
stop = False | |
interval = 3 # Check every 3 seconds for better responsiveness | |
retry = 0 | |
while not stop: | |
# Wait for a short interval before checking again | |
await asyncio.sleep(interval) # Non-blocking sleep | |
job = job_db.job_lookup(job_id) | |
if job: # If the job exists | |
if job['status'] == "RUNNING": # If the job is still running | |
yield { | |
pred_lookup_status: f''' | |
Your job (ID: **{job['id']}**) started at **{job['start_time']}** and is **RUNNING...** | |
It might take a few minutes to a few hours depending on the input size and the queue status. | |
You may keep the page open or close it and revisit later using the job ID. | |
You will receive an email notification once the job is done. | |
''', | |
pred_lookup_btn: gr.update(visible=False), | |
pred_lookup_stop_btn: gr.update(visible=True), | |
} | |
elif job['status'] == "COMPLETED": # If the job is complete | |
stop = True | |
msg = f"Your job (ID: {job['id']}) has been **COMPLETED**" | |
msg += f" at **{job['end_time']}**" if job.get('end_time') else "" | |
msg += f" and the results will **EXPIRE** by **{job['expiry_time']}**." if job.get('expiry_time') else "." | |
msg += " Redirecting to the results page..." | |
yield { | |
pred_lookup_status: msg, | |
tabs: gr.Tabs(selected='result'), | |
result_state: job, | |
pred_lookup_btn: gr.update(visible=True), | |
pred_lookup_stop_btn: gr.update(visible=False), | |
} | |
elif job['status'] == "FAILED": # If the job failed | |
stop = True | |
msg = f'Your job (ID: {job_id}) has **FAILED**' | |
msg += f" at {job['end_time']}" if job.get('end_time') else '' | |
msg += f" due to error: {job['error']}." if job.get('error') else '.' | |
yield { | |
pred_lookup_status: msg, | |
pred_lookup_btn: gr.update(visible=True), | |
pred_lookup_stop_btn: gr.update(visible=False), | |
} | |
else: # If the job is not found | |
stop = retry > 2 # Stop after 3 retries | |
if not stop: | |
msg = f'Job ID {job_id} not found. Retrying... ({retry})' | |
else: | |
msg = f'Job ID {job_id} not found after {retry} retries. Please double-check the job ID.' | |
retry += 1 | |
yield { | |
pred_lookup_status: msg, | |
pred_lookup_btn: gr.update(visible=stop), | |
pred_lookup_stop_btn: gr.update(visible=not stop), | |
} | |
def checkbox_group_selections_to_kwargs(selected_options, option_mapping): | |
kwargs = { | |
option_mapping[label]: label in selected_options | |
for label in option_mapping | |
} | |
return kwargs | |
def job_submit( | |
frag_df, frag_file, prot_file, | |
dock_n_steps, dock_n_poses, dock_confidence_threshold, | |
linker_frag_dist, linker_strategy, linker_n_mols, linker_size, linker_steps, | |
pocket_name, pocket_method, pocket_fs, | |
email, session_info: gr.Request | |
): | |
if len(frag_df) == 0 or not frag_file: | |
raise gr.Error("Please provide a valid fragment library.") | |
if not prot_file: | |
raise gr.Error("Please provide a valid protein structure.") | |
pocket_extraction_method = POCKET_EXTRACT_OPTS[pocket_method]['name'] | |
pocket_path_dict = {} | |
if pocket_extraction_method == 'fpocket': | |
if not pocket_name or not pocket_fs: | |
raise gr.Error("If you wish to use a protein pocket predicted by Fpocket, " | |
"please select a pocket after clicking on 'Extract Pocket'.") | |
else: | |
for pocket_file in pocket_fs: | |
if Path(pocket_file).stem.startswith(pocket_name): | |
pocket_path_dict[pocket_name] = pocket_file | |
if email: | |
try: | |
email_info = validate_email(email, check_deliverability=False) | |
email = email_info.normalized | |
except EmailNotValidError as e: | |
raise gr.Error(f"Invalid email address: {str(e)}.") | |
if check := job_db.check_user_running_job(email, session_info): | |
raise gr.Error(check) | |
gr.Info('Finished processing inputs. Initiating the GenFBDD job... You will be redirected to Job Status page.') | |
job_id = str(uuid.uuid4()) | |
job_info = { | |
'id': job_id, | |
'status': 'QUEUED', | |
'fragment_library_file': frag_file, | |
'protein_structure_file': prot_file, | |
'pocket_extraction_method': pocket_extraction_method, | |
'protein_pocket_files': pocket_path_dict, | |
'email': email, | |
'ip': session_info.headers.get('x-forwarded-for', session_info.client.host), | |
'cookies': dict(session_info.cookies), | |
'start_time': time(), | |
'end_time': None, | |
'expiry_time': None, | |
'error': None | |
} | |
job_db.insert(job_info) | |
task_queue.put(( | |
frag_df, prot_file, | |
dock_n_steps, dock_n_poses, dock_confidence_threshold, | |
linker_frag_dist, linker_strategy, linker_n_mols, linker_size, linker_steps, | |
job_info | |
)) | |
return { | |
pred_lookup_id: job_id, | |
tabs: gr.Tabs(selected='job'), | |
} | |
def dock_link( | |
frag_lib, prot, | |
dock_n_steps, dock_n_poses, dock_confidence_threshold, | |
linker_frag_dist, linker_strategy, linker_n_mols, linker_size, linker_steps, | |
job_info | |
): | |
job_id = job_info['id'] | |
job_db.job_update( | |
job_id=job_id, | |
update_info={'status': 'RUNNING'}, | |
) | |
pocket_extract_method = job_info['pocket_extraction_method'] | |
pocket_path_dict = job_info['protein_pocket_files'] | |
update_info = {} | |
config = OmegaConf.load('configs/gen_fbdd_v1.yaml') | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
print(f'Using device: {device}') | |
date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | |
out_dir = Path(RESULTS_DIR, f'{date_time}_{job_id}') | |
frag_lib['X2'] = prot | |
frag_lib['ID2'] = str(Path(prot).stem) | |
try: | |
docking_df = dock_fragments( | |
df=frag_lib, out_dir=out_dir, | |
score_ckpt=config.score_ckpt, confidence_ckpt=config.confidence_ckpt, | |
inference_steps=dock_n_steps, n_poses=dock_n_poses, | |
docking_batch_size=config.docking_batch_size, | |
initial_noise_std_proportion=config.initial_noise_std_proportion, | |
no_final_step_noise=config.no_final_step_noise, | |
temp_sampling_tr=config.temp_sampling_tr, | |
temp_sampling_rot=config.temp_sampling_rot, | |
temp_sampling_tor=config.temp_sampling_tor, | |
temp_psi_tr=config.temp_psi_tr, | |
temp_psi_rot=config.temp_psi_rot, | |
temp_psi_tor=config.temp_psi_tor, | |
temp_sigma_data_tr=config.temp_sigma_data_tr, | |
temp_sigma_data_rot=config.temp_sigma_data_rot, | |
temp_sigma_data_tor=config.temp_sigma_data_tor, | |
save_docking=True, | |
device=device, | |
) | |
linking_df = select_fragment_pairs( | |
docking_df, | |
method=pocket_extract_method, | |
pocket_path_dict=pocket_path_dict, | |
frag_dist_range=linker_frag_dist, | |
confidence_threshold=dock_confidence_threshold, | |
rmsd_threshold=1.5, | |
out_dir=out_dir, | |
) | |
if linking_df is not None and len(linking_df) > 0: | |
# Generate linkers | |
generate_linkers( | |
linking_df, | |
backbone_atoms_only=True, | |
output_dir=out_dir, | |
n_samples=linker_n_mols, | |
n_steps=linker_steps, | |
linker_size=linker_size, | |
anchors=None, | |
max_batch_size=config.linker_batch_size, | |
random_seed=None, | |
robust=False, | |
linker_ckpt=config.linker_ckpt, | |
size_ckpt=config.size_ckpt, | |
linker_condition=None, | |
device=device, | |
) | |
job_type = 'linking' | |
else: | |
gr.Warning('No fragment-conformer pairs found for linking. Please adjust the docking / linking settings.') | |
job_type = 'docking' | |
update_info = { | |
'status': "COMPLETED", | |
'error': None, | |
'output_dir': str(out_dir), | |
'type': job_type, | |
} | |
except Exception as e: | |
gr.Warning(f"Job failed due to error: {str(e)}") | |
update_info = { | |
'status': "FAILED", | |
'error': str(e), | |
'output_dir': None | |
} | |
finally: | |
job_db.job_update( | |
job_id=job_id, | |
update_info=update_info | |
) | |
THEME = gr.themes.Base( | |
spacing_size="sm", text_size='md', font=gr.themes.GoogleFont("Roboto"), | |
primary_hue='emerald', secondary_hue='emerald', neutral_hue='slate', | |
).set( | |
# body_background_fill='*primary_50' | |
# background_fill_primary='#eef3f9', | |
# background_fill_secondary='white', | |
# checkbox_label_background_fill='#eef3f9', | |
# checkbox_label_background_fill_hover='#dfe6f0', | |
# checkbox_background_color='white', | |
# checkbox_border_color='#4372c4', | |
# border_color_primary='#4372c4', | |
# border_color_accent='#2e6ab5', | |
# button_primary_background_fill='#2e6ab4', | |
# button_primary_text_color='white', | |
# body_text_color='#28496F', | |
# block_background_fill='#fbfcfd', | |
# block_title_text_color='#28496F', | |
# block_label_text_color='#28496F', | |
# block_info_text_color='#505358', | |
# block_border_color=None, | |
# input_border_color='#4372c4', | |
# panel_border_color='#4372c4', | |
# input_background_fill='#F1F2F4', | |
) | |
with gr.Blocks(theme=THEME, title='GenFBDD', css=static.CSS, delete_cache=(3600, 48 * 3600)) as demo: | |
with gr.Column(variant='panel'): | |
with gr.Tabs() as tabs: | |
with gr.Tab(label='Start', id='start'): | |
gr.Markdown(''' | |
# GenFBDD - A Fragment-Based Drug Design Protocol Based on SOTA Molecular Generative Models | |
Given a fragment library and a target protein, GenFBDD blindly docks the fragments to the | |
protein and generates linkers connecting the selected fragments, generating novel scaffolds | |
or drug-like molecules with desirable binding conformations. | |
''') | |
with gr.Row(): | |
with gr.Column(variant='panel'): | |
gr.Markdown('## Chemical Fragment Library') | |
# Fragment settings | |
frag_lib_dropdown = gr.Dropdown( | |
label='Select a Preset Fragment Library', | |
choices=list(FRAG_LIBS.keys()), | |
value='', | |
) | |
frag_lib_upload_btn = gr.UploadButton( | |
label='OR Upload Your Own Library', variant='primary', interactive=True, | |
) | |
frag_lib_file = gr.File( | |
value=None, label='Fragment Library File (Original)', | |
file_count='single', file_types=['.sdf', '.csv'], | |
interactive=False, visible=False | |
) | |
frag_lib_orig_df = gr.State(value=pd.DataFrame(columns=['X1', 'ID1', 'mol'])) | |
frag_lib_mod_df = gr.State(value=pd.DataFrame(columns=['X1', 'ID1', 'mol'])) | |
# frag_lib_view = gr.DataFrame( | |
# value=pd.DataFrame(columns=['X1', 'ID1']), elem_id='frag_lib_view', | |
# visible=True, interactive=False, | |
# ) | |
frag_lib_view = gr.HTML(static.IFRAME_TEMPLATE.format(aspect_ratio='1.618 /1', srcdoc='')) | |
with gr.Group(): | |
frag_lib_process_opts = gr.CheckboxGroup( | |
label='Fragment Preparation Options', | |
info='1) All fragments consisting of multiple fragments will be split into individual ' | |
'fragments. 2) All fragments consisting of a single heavy atom will be discarded. ' | |
'3) All fragments will then be processed in the order of the selected options. ' | |
'4) Finally, fragments will be deduplicated based on their SMILES.', | |
choices=list(FRAG_LIB_PROCESS_OPTS.keys()), | |
value=['Dehalogenate Fragments', 'Discard Inorganic Fragments'], | |
interactive=True, | |
) | |
frag_lib_process_btn = gr.Button( | |
value='Process Fragments', variant='primary', interactive=True, | |
) | |
# Fragment library preview | |
with gr.Column(variant='panel'): | |
gr.Markdown('## Target Protein Structure') | |
# Protein settings | |
with gr.Row(equal_height=True): | |
prot_query_dropdown = gr.Dropdown( | |
label='Select a Protein Structure Query Strategy', | |
choices=[ | |
'PDB ID', | |
'UniProt ID', | |
'FASTA Sequence', | |
], | |
interactive=True, | |
scale=4 | |
) | |
prot_query_input = gr.Textbox( | |
show_label=False, placeholder='Enter the protein query here', | |
scale=3, interactive=True | |
) | |
with gr.Row(): | |
prot_query_btn = gr.Button( | |
value='Query', variant='primary', | |
scale=1, interactive=True | |
) | |
prot_upload_btn = gr.UploadButton( | |
label='OR Upload Your PDB/FASTA File', variant='primary', | |
file_types=['.pdb', '.fasta'], | |
scale=2, interactive=True, | |
) | |
input_prot_file = gr.File( | |
value=None, label='Protein Structure File (Original)', | |
interactive=False, visible=False, file_count='single', | |
) | |
input_prot_view = gr.HTML(value='<div id="input_protein_view" class="mol-container"></div>') | |
with gr.Group(): | |
pocket_extract_dropdown = gr.Dropdown( | |
label='Select a Pocket Extraction Method', | |
choices=list(POCKET_EXTRACT_OPTS.keys()), | |
info=POCKET_EXTRACT_OPTS[list(POCKET_EXTRACT_OPTS.keys())[0]]['info'], | |
value=list(POCKET_EXTRACT_OPTS.keys())[0], | |
interactive=True, | |
) | |
selected_pocket = gr.Textbox(visible=False) | |
selected_ligand = gr.Textbox(visible=False) | |
pocket_files = gr.Files(visible=False) | |
pocket_extract_btn = gr.Button( | |
value='Extract Pocket', variant='primary', interactive=True | |
) | |
with gr.Row(): | |
with gr.Column(variant='panel'): | |
gr.Markdown('## Dock Phase Settings') | |
dock_n_poses = gr.Slider( | |
value=5, minimum=1, maximum=20, step=1, | |
label="Number of conformers to generate per fragment", | |
interactive=True | |
) | |
dock_confidence_cutoff = gr.Slider( | |
value=-1.0, minimum=-2.0, maximum=0, step=0.1, | |
label="Confidence cutoff for filtering conformers of docked fragments (>0: high, <=-1.5: low)", | |
interactive=True | |
) | |
with gr.Accordion(label='Advanced Options', open=False): | |
dock_model = gr.Dropdown( | |
label='Select a Fragment Docking Model', | |
choices=['DiffDock-L'], | |
interactive=True, | |
) | |
dock_steps = gr.Slider( | |
minimum=20, maximum=40, step=1, | |
label="Number of Denoising Steps for Docking Fragments", | |
interactive=True | |
) | |
with gr.Column(variant='panel'): | |
gr.Markdown('## Link Phase Settings') | |
link_frag_pose_strategy = gr.Radio( | |
label='Select a Fragment-Conformer Linking Strategy', | |
choices=[ | |
'Link Pairs of Fragment-Conformers Contacting the Pocket', | |
# 'Link Maximal Fragment-Conformers Spanning the Entire Pocket', | |
], | |
value='Link Pairs of Fragment-Conformers Contacting the Pocket', | |
) | |
link_frag_dist_range = RangeSlider( | |
value=[2, 8], minimum=1, maximum=10, step=1, | |
label="Fragment-Conformer Distance Range (Å) Eligible for Linking", | |
interactive=True | |
) | |
link_n_mols = gr.Slider( | |
value=10, minimum=1, maximum=20, step=1, | |
label="Number of molecules to generate per fragment conformer combination", | |
interactive=True | |
) | |
with gr.Accordion(label='Advanced Options', open=False): | |
link_model = gr.Dropdown( | |
label='Select a Linker Generation Model', | |
choices=['DiffLinker'], | |
interactive=True, | |
) | |
link_linker_size = gr.Slider( | |
minimum=0, maximum=20, step=1, | |
label="Linker Size", | |
info="0: automatically predicted; >=1: fixed size", | |
interactive=True | |
) | |
link_steps = gr.Slider( | |
minimum=100, maximum=500, step=10, | |
label="Number of Denoising Steps for Generating Linkers", | |
interactive=True | |
) | |
with gr.Row(equal_height=True): | |
email_input =gr.Textbox( | |
label='Email Address (Optional)', | |
info="Your email address will be used to notify you of the status of your job. " | |
"If you cannot receive the email, please check your spam/junk folder.", | |
type='email' | |
) | |
with gr.Column(): | |
start_clr_btn = gr.ClearButton( | |
value='Reset Inputs', interactive=True, | |
) | |
run_btn = gr.Button( | |
value='Run GenFBDD', variant='primary', interactive=True, | |
) | |
with gr.Tab(label='Jobs', id='job'): | |
gr.Markdown(''' | |
To check the status of an in-progress or historical job using the job ID and retrieve the predictions | |
if the job has completed. Note that predictions are only kept for 48 hours upon job completion. | |
You will be redirected to `Results` for carrying out further analysis and | |
generating the full report when the job is done. If the the query fails to respond, please wait for a | |
few minutes and refresh the page to try again. | |
''') | |
with gr.Row(): | |
with gr.Column(scale=1): | |
loader_html = gr.HTML('<div class="loader first-frame"></div>', visible=False) | |
with gr.Column(scale=4): | |
pred_lookup_id = gr.Textbox( | |
label='Input Your Job ID', placeholder='e.g., e9dfd149-3f5c-48a6-b797-c27d027611ac', | |
info="Your job ID is a UUID4 string that you receive after submitting a job on the " | |
"page or in the email notification.") | |
pred_lookup_btn = gr.Button(value='Query Status', variant='primary', visible=True) | |
pred_lookup_stop_btn = gr.Button(value='Stop Tracking', variant='stop', visible=False) | |
pred_lookup_status = gr.Markdown("**Job Status**", container=True) | |
with gr.Tab(label='Results', id='result'): | |
# Results | |
result_state = gr.State(value={}) | |
result_table_orig_df = gr.State(value=pd.DataFrame()) | |
result_table_mod_df = gr.State(value=pd.DataFrame()) | |
result_protein_file = gr.File(visible=False, interactive=False) | |
with gr.Column(variant='panel'): | |
with gr.Row(): | |
scores = gr.CheckboxGroup(list(fn.SCORE_MAP.keys()), label='Compound Scores') | |
filters = gr.CheckboxGroup(list(fn.FILTER_MAP.keys()), label='Compound Filters') | |
result_example = gr.Button('Example', elem_classes=['example']) | |
with gr.Row(): | |
prop_clr_btn = gr.ClearButton(value='Clear Properties', interactive=False) | |
prop_calc_btn = gr.Button(value='Calculate Properties', interactive=False, variant='primary') | |
with gr.Row(): | |
result_table_view = gr.HTML('<div id="result_view" class="fancy-table"></div>') | |
with gr.Column(): | |
result_prot_view = gr.HTML('<div id="result_protein_view" class="mol-container"></div>') | |
result_file_btn = gr.Button(value='Create Result File', visible=False, variant='primary') | |
result_download_file = gr.File(label='Download Result File', visible=False) | |
# Event handlers | |
## Start tab | |
### Fragment Library | |
frag_lib_dropdown_change = frag_lib_dropdown.change( | |
fn=lambda lib: gr.File(FRAG_LIBS[lib], visible=bool(lib)), | |
inputs=[frag_lib_dropdown], | |
outputs=[frag_lib_file], | |
) | |
frag_lib_upload_btn.upload( | |
fn=lambda file: gr.File(str(Path(file)), visible=True), | |
inputs=[frag_lib_upload_btn], | |
outputs=[frag_lib_file], | |
) | |
# Changing the file updates the original df, the modified df, and the view | |
frag_lib_file.change( | |
fn=gr_error_wrapper(read_fragment_library), | |
inputs=[frag_lib_file], | |
outputs=[frag_lib_orig_df], | |
).success( | |
fn=lambda df: [df.copy(), fn.create_result_table_html(fn.prepare_df_for_table(df))], | |
inputs=[frag_lib_orig_df], | |
outputs=[frag_lib_mod_df, frag_lib_view], | |
) | |
# Processing the fragment library updates the modified df | |
frag_lib_process_btn.click( | |
fn=lambda: gr.Info('Processing fragment library...'), | |
).then( | |
fn=lambda df, opts: [ | |
new_df:=process_fragment_library( | |
df, **checkbox_group_selections_to_kwargs(opts, FRAG_LIB_PROCESS_OPTS) | |
), | |
fn.create_result_table_html(fn.prepare_df_for_table(new_df)) | |
], | |
inputs=[frag_lib_orig_df, frag_lib_process_opts], | |
outputs=[frag_lib_mod_df, frag_lib_view], | |
) | |
def preprocess_protein_file(file): | |
filepath = Path(file.name) | |
if filepath.suffix == '.pdb': | |
return { | |
input_prot_file: gr.File(str(filepath), visible=True), | |
} | |
elif filepath.suffix == '.fasta': | |
seq = next(SeqIO.parse(file, 'fasta')).seq | |
filepath = pdb_query(seq, method='FASTA Sequence') | |
return { | |
input_prot_file: gr.File(str(filepath), visible=True), | |
prot_query_input: seq, | |
prot_query_dropdown: 'FASTA Sequence', | |
} | |
### Protein Structure | |
# prot_upload_btn.upload( | |
# fn=lambda file: gr.File(str(Path(file)), visible=True), | |
# inputs=[prot_upload_btn], | |
# outputs=[prot_file], | |
# ) | |
# prot_file.change( | |
# fn=lambda file: gr.HTML(fn.create_complex_view_html(file), visible=True), | |
# inputs=[prot_file], | |
# outputs=[input_prot_view], | |
# ) | |
prot_upload_btn.upload( | |
fn=preprocess_protein_file, | |
inputs=[prot_upload_btn], | |
outputs=[input_prot_file, prot_query_dropdown, prot_query_input], | |
) | |
def pdb_query(query, method): | |
"""Downloads protein structure data or searches FASTA sequence.""" | |
gr.Info(f'Querying protein by {method}...') | |
try: | |
if method == 'PDB ID': | |
url = f"https://files.rcsb.org/download/{query}.pdb" | |
file = fn.download_file(url) | |
elif method == 'UniProt ID': | |
pdb_ids = fn.uniprot_to_pdb(query) | |
if pdb_ids: | |
# Download the first associated PDB file | |
file = fn.download_file(f"https://files.rcsb.org/download/{pdb_ids[0]}.pdb") | |
else: | |
raise ValueError(f"No PDB IDs found for UniProt ID: {query}") | |
elif method == 'FASTA Sequence': | |
pdb_ids = fn.fasta_to_pdb(query) | |
if pdb_ids: | |
# Download the first associated PDB file | |
file = fn.download_file(f"https://files.rcsb.org/download/{pdb_ids[0]}.pdb") | |
else: | |
raise ValueError("No PDB IDs found for the provided FASTA sequence.") | |
else: | |
raise ValueError(f"Unsupported method: {method}") | |
return {input_prot_file: gr.File(str(file), visible=True)} | |
except Exception as e: | |
gr.Warning(f"Query error: {str(e)}") | |
prot_query_btn.click( | |
fn=pdb_query, | |
inputs=[prot_query_input, prot_query_dropdown], | |
outputs=[input_prot_file], | |
) | |
input_prot_file.change( | |
fn=lambda: gr.Info('Rendering 3DMol view...'), | |
).then( | |
fn=lambda x, y: gr.Info('3DMol view rendered.'), | |
inputs=[input_prot_file, input_prot_view], | |
js=static.CREATE_INPUT_MOL_VIEW, | |
) | |
#### Pocket Extraction | |
pocket_extract_dropdown.select( | |
fn=lambda method: gr.Button(visible=False) if POCKET_EXTRACT_OPTS[method] == 'clustering' | |
else gr.Button(visible=True), | |
inputs=[pocket_extract_dropdown], | |
outputs=[pocket_extract_btn], | |
) | |
# pocket_extract_btn.click( | |
# fn=lambda: gr.Info('Extracting pocket...'), | |
# ).then( | |
# fn=fn.extract_pockets_and_update_view, | |
# js=static.RETURN_LIGAND_SELECTION_JS, | |
# inputs=[prot_file, selected_ligand], | |
# outputs=[input_prot_view, pocket_path_dict, selected_ligand, selected_pocket], | |
# ) | |
pocket_extract_btn.click( | |
fn=lambda: gr.Info('Extracting pocket...') | |
).success( | |
fn=lambda x, y: [x, y], | |
js=static.RETURN_SELECTION, | |
inputs=[selected_ligand, selected_pocket], | |
outputs=[selected_ligand, selected_pocket], | |
).then( | |
fn=lambda prot, lig: [list(extract_pockets(prot, lig).values()), '', ''], | |
inputs=[input_prot_file, selected_ligand], | |
outputs=[pocket_files, selected_ligand, selected_pocket], | |
).success( | |
fn=lambda x, y: gr.Info('Pocket extraction completed.'), | |
js=static.UPDATE_MOL_VIEW, | |
inputs=[pocket_files, input_prot_view], | |
) | |
### Dock-Link Pipeline | |
job_valid = run_btn.click( | |
fn=lambda x, y: [x, y], | |
js=static.RETURN_SELECTION, | |
inputs=[selected_ligand, selected_pocket], | |
outputs=[selected_ligand, selected_pocket], | |
).success( | |
fn=job_submit, | |
inputs=[ | |
frag_lib_mod_df, frag_lib_file, input_prot_file, | |
dock_steps, dock_n_poses, dock_confidence_cutoff, | |
link_frag_dist_range, link_frag_pose_strategy, link_n_mols, | |
link_linker_size, link_steps, | |
selected_pocket, pocket_extract_dropdown, pocket_files, | |
email_input, | |
], | |
outputs=[pred_lookup_id, tabs], | |
) | |
start_reset_components=[ | |
frag_lib_dropdown, frag_lib_process_opts, | |
prot_query_dropdown, prot_query_input, input_prot_file, | |
dock_n_poses, dock_confidence_cutoff, dock_model, dock_steps, | |
link_frag_pose_strategy, link_n_mols, link_frag_dist_range, link_model, link_linker_size, link_steps, | |
email_input | |
] | |
def reset_components(components): | |
return [ | |
type(component)( | |
value=component.value, | |
visible=component.visible, | |
) for component in components | |
] | |
start_clr_btn.click( | |
fn=lambda: reset_components(start_reset_components), | |
outputs=start_reset_components, | |
show_progress='hidden', | |
) | |
### Job Status | |
user_job_lookup = pred_lookup_btn.click( | |
fn=lambda: [ | |
gr.update(value="Start querying the job database..."), | |
gr.update(visible=False), | |
gr.update(visible=True), | |
], | |
outputs=[pred_lookup_status, pred_lookup_btn, pred_lookup_stop_btn], | |
).success( | |
fn=query_job_status, | |
inputs=[pred_lookup_id], | |
outputs=[pred_lookup_status, tabs, result_state, pred_lookup_btn, pred_lookup_stop_btn], | |
show_progress='minimal', | |
) | |
auto_job_lookup = job_valid.success( | |
fn=lambda: [ | |
gr.update(value="Start querying the job database..."), | |
gr.update(visible=False), | |
gr.update(visible=True), | |
], | |
outputs=[pred_lookup_status, pred_lookup_btn, pred_lookup_stop_btn], | |
).success( | |
fn=query_job_status, | |
inputs=pred_lookup_id, | |
outputs=[pred_lookup_status, tabs, result_state, pred_lookup_btn, pred_lookup_stop_btn], | |
show_progress='minimal', | |
cancels=[user_job_lookup], | |
) | |
pred_lookup_stop_btn.click( | |
fn=lambda: [gr.Button(visible=True), gr.Button(visible=False)], | |
outputs=[pred_lookup_btn, pred_lookup_stop_btn], | |
cancels=[user_job_lookup, auto_job_lookup], | |
) | |
result_example.click( | |
fn=lambda: '80cf2658-7a1c-48d6-8372-61b978177fe6', | |
outputs=[pred_lookup_id], | |
show_progress='hidden' | |
).success( | |
fn=query_job_status, | |
inputs=pred_lookup_id, | |
outputs=[pred_lookup_status, tabs, result_state, pred_lookup_btn, pred_lookup_stop_btn], | |
show_progress='minimal', | |
cancels=[user_job_lookup, auto_job_lookup], | |
) | |
### Results | |
def update_results(result_info): | |
result_dir = Path(result_info['output_dir']) | |
result_type = result_info['type'] | |
protein_structure_file = Path(result_info['protein_structure_file']) | |
if result_type == 'docking': | |
result_df = pd.read_csv(result_dir / 'docking_summary.csv', dtype=fn.COL_DTYPE) | |
result_df['mol'] = result_df['X1'].apply(Chem.MolFromSmiles) | |
elif result_type == 'linking': | |
result_df = pd.read_csv(result_dir / 'linking_summary.csv', dtype=fn.COL_DTYPE) | |
result_df['mol'] = result_df['X1^'].apply(Chem.MolFromSmiles) | |
result_df = result_df[~result_df['X1^'].str.contains('.', regex=False)] | |
else: | |
raise gr.Error('Invalid result type') | |
result_df = fn.prepare_df_for_table(result_df) | |
return { | |
result_table_orig_df: result_df, | |
result_table_mod_df: result_df.copy(deep=True), | |
result_protein_file: str(protein_structure_file), | |
result_download_file: gr.File(None, visible=False), | |
} | |
def update_table(orig_df, score_list, filter_list, progress=gr.Progress(track_tqdm=True)): | |
gr.Info('Calculating properties...') | |
mod_df = orig_df.copy() | |
try: | |
for filter_name in filter_list: | |
mod_df[filter_name] = mod_df['mol'].apply( | |
lambda x: fn.FILTER_MAP[filter_name](x) if not pd.isna(x) else x) | |
for score_name in score_list: | |
mod_df[score_name] = mod_df['mol'].apply( | |
lambda x: fn.SCORE_MAP[score_name](x) if not pd.isna(x) else x) | |
except Exception as e: | |
gr.Warning(f'Failed to calculate properties due to error: {str(e)}') | |
finally: | |
return {result_table_mod_df: mod_df} | |
result_state.change( | |
fn=update_results, | |
inputs=[result_state], | |
outputs=[result_table_orig_df, result_table_mod_df, result_protein_file, result_download_file], | |
) | |
result_protein_file.change( | |
fn=lambda x, y: gr.Info('Rendering result table and 3DMol view...'), | |
js=static.CREATE_OUTPUT_MOL_VIEW, | |
inputs=[result_protein_file, result_prot_view], | |
) | |
result_table_mod_df.change( | |
fn=fn.create_result_table_html, | |
inputs=[result_table_mod_df, result_state], | |
outputs=[result_table_view] | |
).success( | |
fn=lambda x: [gr.Button(visible=True), gr.Button(interactive=True), gr.Button(interactive=True)], | |
inputs=[result_file_btn], | |
outputs=[result_file_btn, prop_calc_btn, prop_clr_btn], | |
) | |
prop_calc_btn.click( | |
fn=update_table, | |
inputs=[result_table_orig_df, scores, filters], | |
outputs=[result_table_mod_df], | |
) | |
prop_clr_btn.click( | |
fn=lambda orig_df: [orig_df, [], [], gr.File(visible=False)], | |
inputs=[result_table_orig_df], | |
outputs=[result_table_mod_df, scores, filters, result_download_file], | |
) | |
def generate_result_zip(result_info, compound_mod_df, protein_file): | |
result_path = Path(result_info['output_dir']) | |
zip_filename = f'GenFBDD_{result_path.name}.zip' | |
summary_filename = f'{result_info["type"]}_summary.csv' | |
zip_path = result_path / zip_filename | |
cols_to_drop = ['mol', 'Compound', 'protein_path'] | |
compound_mod_df.drop(columns=[col for col in cols_to_drop if col in compound_mod_df.columns], inplace=True) | |
compound_mod_df.rename(columns=fn.COL_ALIASES, inplace=True) | |
with zipfile.ZipFile(zip_path, 'w') as zip_file: | |
for file in result_path.rglob('*'): | |
# Skip directories, the zip file itself and the new summary file | |
if file.is_file() and file.name not in [zip_filename, summary_filename]: | |
archive_path = file.relative_to(result_path) | |
zip_file.write(file, arcname=archive_path) | |
if Path(protein_file).name not in zip_file.namelist(): | |
zip_file.write(Path(protein_file), arcname=Path(protein_file).name) | |
csv_buffer = io.BytesIO() | |
compound_mod_df.to_csv(csv_buffer, index=False) | |
zip_file.writestr(summary_filename, csv_buffer.getvalue()) | |
return gr.File(str(zip_path), visible=True) | |
result_file_btn.click( | |
fn=generate_result_zip, | |
inputs=[result_state, result_table_mod_df, result_protein_file], | |
outputs=[result_download_file], | |
) | |
demo.load(None, None, None, js=static.SETUP_JS) | |
demo.share_token = 'genfbdd' | |
demo.queue(default_concurrency_limit=None) | |
demo.launch( | |
server_name='0.0.0.0', | |
max_file_size="5mb", | |
ssr_mode=False, | |
show_api=False, | |
enable_monitoring=True, | |
) | |