Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import gradio as gr | |
from gradio_molecule3d import Molecule3D | |
import spaces | |
import subprocess | |
import glob | |
# directory to store cached outputs
CACHE_DIR = "gradio_cached_examples"

# Representation settings passed to the Molecule3D viewer (stick style,
# "whiteCarbon" colouring, model 0). The empty chain/resname/residue_range
# selectors presumably mean "no filter" — confirm against gradio_molecule3d docs.
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "stick",
        "color": "whiteCarbon",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": False,
    }
]

# Ensure the cache directory exists
os.makedirs(CACHE_DIR, exist_ok=True)

# Example inputs offered in the UI. Each is expected to have a precomputed
# `<stem>.pdb` inside CACHE_DIR (see example_outputs below).
example_fasta_files = [
    "cache_examples/boltz_0.fasta",
    "cache_examples/Armadillo_6.fasta",
    "cache_examples/Covid_3.fasta",
    "cache_examples/Malaria_2.fasta",
    "cache_examples/MITOCHONDRIAL_9.fasta",
    "cache_examples/Monkeypox_4.fasta",
    "cache_examples/Plasmodium_1.fasta",
    "cache_examples/PROTOCADHERIN_8.fasta",
    "cache_examples/Vault_5.fasta",
    "cache_examples/Zipper_7.fasta",
]

# Matching `.pdb` paths in CACHE_DIR, in the same order as the examples.
# Only the trailing extension is swapped; str.replace would also rewrite a
# ".fasta" occurring earlier in the file name.
example_outputs = [
    os.path.join(CACHE_DIR, os.path.splitext(os.path.basename(fasta_file))[0] + ".pdb")
    for fasta_file in example_fasta_files
]
# must load cached outputs
def load_cached_example_outputs(fasta_file: str, cache_dir=None) -> str:
    """Return the path of the precomputed ``.pdb`` output for an example input.

    The cached structure is looked up by file name only: ``<stem>.pdb`` inside
    ``cache_dir``, where ``<stem>`` is the basename of ``fasta_file`` with its
    final extension removed.

    Args:
        fasta_file: Path (or bare file name) of the example FASTA input.
        cache_dir: Directory holding the cached ``.pdb`` files. ``None``
            (the default) means the module-level ``CACHE_DIR``.

    Returns:
        Path to the cached ``.pdb`` file.

    Raises:
        FileNotFoundError: If no cached ``.pdb`` file exists for ``fasta_file``.
    """
    if cache_dir is None:
        cache_dir = CACHE_DIR
    # Swap only the trailing extension: str.replace(".fasta", ".pdb") would
    # rewrite an inner ".fasta" and leave other extensions (e.g. ".fa") as-is.
    stem, _ext = os.path.splitext(os.path.basename(fasta_file))
    pdb_file = stem + ".pdb"
    cached_pdb_path = os.path.join(cache_dir, pdb_file)
    # isfile (not exists): the viewer needs an actual file, not a directory.
    if os.path.isfile(cached_pdb_path):
        return cached_pdb_path
    raise FileNotFoundError(f"Cached output not found for {pdb_file}")
# handle example click
def on_example_click(fasta_file: str) -> str:
    """Gradio ``Examples`` callback: map an example input to its cached ``.pdb``.

    Delegates to ``load_cached_example_outputs``; any ``FileNotFoundError``
    raised there for a missing cache entry propagates unchanged.
    """
    cached_path = load_cached_example_outputs(fasta_file)
    return cached_path
# run predictions
# NOTE(review): this Space appears to run on ZeroGPU and the file imports
# `spaces`, but nothing here is decorated with `@spaces.GPU`. On ZeroGPU a GPU
# is normally only attached inside such functions, so the default
# accelerator="gpu" may fail. Confirm and consider decorating run_prediction.
def predict(data,
            accelerator="gpu", sampling_steps=50,
            diffusion_samples=1, out_dir="./",
            checkpoint="./ckpt/boltz1.ckpt", executable="boltz"):
    """Run ``boltz predict`` on one input file and report whether it succeeded.

    Args:
        data: Path of the input file handed to boltz (e.g. a ``.fasta``).
        accelerator: Value for ``--accelerator`` (the UI offers "gpu"/"cpu").
        sampling_steps: Value for ``--sampling_steps``; coerced to ``int`` so
            a float coming from a UI slider is not rendered as "50.0".
        diffusion_samples: Value for ``--diffusion_samples``; coerced to ``int``.
        out_dir: Directory passed as ``--out_dir``.
        checkpoint: Checkpoint path passed as ``--checkpoint``.
        executable: Program to run; ``"boltz"`` is resolved via ``PATH``.

    Returns:
        ``True`` if the process exited with status 0, ``False`` otherwise.

    Raises:
        OSError: Propagated from ``subprocess.run`` if ``executable`` cannot
            be started (e.g. ``FileNotFoundError`` when it is not installed).
    """
    print("Arguments passed to `predict` function:")
    print(f" data: {data}")
    print(f" accelerator: {accelerator}")
    print(f" sampling_steps: {sampling_steps}")
    print(f" diffusion_samples: {diffusion_samples}")
    # Argument list (no shell), so file names are never interpreted by a shell.
    command = [
        executable, "predict",
        "--out_dir", out_dir,
        "--accelerator", accelerator,
        "--sampling_steps", str(int(sampling_steps)),
        "--diffusion_samples", str(int(diffusion_samples)),
        "--output_format", "pdb",
        "--checkpoint", checkpoint,
        data,
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode == 0:
        print("Prediction completed successfully...!")
        # out_dir used to be undefined here, raising NameError on every success.
        print(f"Output saved to: {out_dir}")
        return True
    print("Prediction failed :(")
    print("Error:", result.stderr)
    return False


def run_prediction(input_file, accelerator, sampling_steps,
                   diffusion_samples):
    """Gradio click handler: run a prediction and return the resulting ``.pdb``.

    Args:
        input_file: The uploaded file from ``gr.File``. Either an object with
            a ``.name`` path attribute or a plain path string; ``None`` when
            nothing was uploaded.
        accelerator: Forwarded to ``predict``.
        sampling_steps: Forwarded to ``predict``.
        diffusion_samples: Forwarded to ``predict``.

    Returns:
        Path of the most recently modified ``.pdb`` matching
        ``./boltz_results*/predictions/**/*.pdb``. Returns ``None`` if no file
        was uploaded, the prediction failed, or no ``.pdb`` was found.
    """
    if input_file is None:
        print("No input file provided.")
        return None
    # Accept both file-like objects (with .name) and plain path strings.
    data = getattr(input_file, "name", input_file)
    print("the data : ", data)
    out_dir = "./"
    succeeded = predict(
        data=data,
        accelerator=accelerator,
        sampling_steps=sampling_steps,
        diffusion_samples=diffusion_samples,
        out_dir=out_dir,
    )
    if not succeeded:
        # Otherwise a failed run would fall through and display the newest
        # .pdb left behind by some earlier prediction.
        return None
    # search for the latest .pdb file in the predictions folder
    search_path = os.path.join(out_dir, "boltz_results*/predictions/**/*.pdb")
    pdb_files = glob.glob(search_path, recursive=True)
    if not pdb_files:
        print("No .pdb files found in the predictions folder.")
        return None
    # Several results directories/models may match; pick the newest by mtime.
    return max(pdb_files, key=os.path.getmtime)
with gr.Blocks() as demo:
    gr.Markdown("# 🔬 Boltz-1: Democratizing Biomolecular Interaction Modeling 🧬")
    with gr.Row():
        # Left column: input file, optional settings, and the trigger button.
        with gr.Column(scale=1):
            inp = gr.File(label="Upload a .fasta File", file_types=[".fasta"])
            with gr.Accordion("Advanced Settings", open=False):
                accelerator = gr.Radio(choices=["gpu", "cpu"], value="gpu", label="Accelerator")
                sampling_steps = gr.Slider(minimum=1, maximum=500, value=50, step=1, label="Sampling Steps")
                diffusion_samples = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Diffusion Samples")
            btn = gr.Button("Predict")
        # Right column: 3D viewer showing the returned .pdb path.
        with gr.Column(scale=3):
            out = Molecule3D(label="Generated Molecule", reps=reps)
    # Run a fresh prediction; run_prediction returns a .pdb path (or None).
    btn.click(
        run_prediction,
        inputs=[inp, accelerator, sampling_steps, diffusion_samples],
        outputs=out,
    )
    # Examples are served from precomputed outputs, not by running boltz.
    # The handler is passed directly; wrapping it in a lambda added nothing.
    gr.Examples(
        examples=[[fasta_file] for fasta_file in example_fasta_files],
        inputs=[inp],
        outputs=out,
        fn=on_example_click,
        cache_examples=True,
    )
if __name__ == "__main__":
    # NOTE(review): share=True creates a public link when run locally; it is
    # likely ignored (with a warning) when hosted on Hugging Face Spaces.
    demo.launch(share=True, debug=True)