File size: 5,945 Bytes
44459bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
"""Boltz-1 folding submission command."""
import json
from datetime import datetime
from pathlib import Path
from typing import Annotated, Any, Optional
import typer
from rich.json import JSON
from rich.panel import Panel
from folding_studio.client import Client
from folding_studio.commands.utils import (
success_fail_catch_print,
success_fail_catch_spinner,
)
from folding_studio.config import FOLDING_API_KEY
from folding_studio.console import console
from folding_studio.query.boltz import BoltzQuery
def _load_parameters_json(parameters_json: Path) -> dict[str, Any]:
    """Read Boltz inference parameters from a JSON file.

    Args:
        parameters_json: Path to a UTF-8 JSON file whose top-level value is an object.

    Returns:
        The decoded JSON object, to be merged over the CLI-provided parameters.

    Raises:
        ValueError: If the file cannot be read, is not valid JSON, or its
            top-level value is not a JSON object.
    """
    try:
        # JSON text is UTF-8 by specification; don't depend on the locale encoding.
        with open(parameters_json, "r", encoding="utf-8") as f:
            json_parameters = json.load(f)
    except (OSError, ValueError) as err:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise ValueError(f"Error reading JSON file: {err}") from err
    if not isinstance(json_parameters, dict):
        # Only a JSON object can be merged key-by-key into the parameter dict.
        raise ValueError(
            "Error reading JSON file: expected a JSON object at the top level, "
            f"got {type(json_parameters).__name__}."
        )
    return json_parameters


# CLI command: submit a Boltz-1 folding job, print the generated query and the
# returned confidence data, then download the results under a timestamped
# `submission_<YYYYmmddHHMMSS>` directory inside `output`.
def boltz(
    source: Annotated[
        Path,
        typer.Argument(
            help=(
                "Path to the data source. Either a FASTA file, a YAML file, "
                "or a directory containing FASTA and YAML files."
            ),
            exists=True,
        ),
    ],
    project_code: Annotated[
        str,
        typer.Option(
            help="Project code. If unknown, contact your PM or the Folding Studio team.",
            envvar="FOLDING_PROJECT_CODE",
        ),
    ],
    parameters_json: Annotated[
        Path | None,
        typer.Option(help="Path to JSON file containing Boltz inference parameters."),
    ] = None,
    recycling_steps: Annotated[
        int, typer.Option(help="Number of recycling steps for prediction.")
    ] = 3,
    sampling_steps: Annotated[
        int, typer.Option(help="Number of sampling steps for prediction.")
    ] = 200,
    diffusion_samples: Annotated[
        int, typer.Option(help="Number of diffusion samples for prediction.")
    ] = 1,
    step_scale: Annotated[
        float,
        typer.Option(
            help="Step size related to the temperature at which the diffusion process samples the distribution."
        ),
    ] = 1.638,
    msa_pairing_strategy: Annotated[
        str, typer.Option(help="Pairing strategy for MSA generation.")
    ] = "greedy",
    write_full_pae: Annotated[
        bool, typer.Option(help="Whether to save the full PAE matrix as a file.")
    ] = False,
    write_full_pde: Annotated[
        bool, typer.Option(help="Whether to save the full PDE matrix as a file.")
    ] = False,
    use_msa_server: Annotated[
        bool,
        typer.Option(help="Flag to use the MSA server for inference.", is_flag=True),
    ] = True,
    msa_path: Annotated[
        Optional[str],
        typer.Option(
            help="Path to the custom MSAs. It can be a .a3m or .aligned.pqt file, or a directory containing these files."
        ),
    ] = None,
    seed: Annotated[
        int | None, typer.Option(help="Seed for random number generation.")
    ] = 0,
    output: Annotated[
        Path,
        typer.Option(
            help="Local path to download the result zip and query parameters to. "
            "Default to 'boltz_results'."
        ),
    ] = Path("boltz_results"),
    force: Annotated[
        bool,
        typer.Option(
            help=(
                "Forces the download to overwrite any existing file "
                "with the same name in the specified location."
            )
        ),
    ] = False,
    unzip: Annotated[
        bool, typer.Option(help="Unzip the file after its download.")
    ] = False,
    spinner: Annotated[
        bool, typer.Option(help="Use live spinner in log output.")
    ] = True,
):
    """Synchronous Boltz-1 folding submission."""
    # Pick the progress reporter: a live spinner, or plain printed status lines.
    success_fail_catch = (
        success_fail_catch_spinner if spinner else success_fail_catch_print
    )
    # If a custom MSA path is provided, disable automated MSA search.
    # NOTE(review): this only reacts to the CLI `--msa-path`; a
    # `custom_msa_paths` entry coming from --parameters-json does not flip
    # `use_msa_server` — confirm whether that is intended.
    if msa_path is not None:
        console.print(
            "\n[yellow]:warning: Custom MSA path provided. Disabling automated MSA search.[/yellow]"
        )
        use_msa_server = False
    console.print(
        Panel("[bold cyan]:dna: Boltz1 Folding submission [/bold cyan]", expand=False)
    )
    # Path() coercion keeps direct (non-Typer) calls with a str `output` working.
    output_dir = (
        Path(output) / f"submission_{datetime.now().strftime('%Y%m%d%H%M%S')}"
    )
    # Initialize parameters with CLI-provided values
    parameters = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "step_scale": step_scale,
        "msa_pairing_strategy": msa_pairing_strategy,
        "write_full_pae": write_full_pae,
        "write_full_pde": write_full_pde,
        "use_msa_server": use_msa_server,
        "seed": seed,
        "custom_msa_paths": msa_path,
    }
    if parameters_json:
        json_parameters = _load_parameters_json(parameters_json)
        console.print(
            ":warning: Parameters specified in the configuration file will "
            "take precedence over the CLI options."
        )
        # JSON keys overwrite the matching CLI values (dict.update semantics).
        parameters.update(json_parameters)
    # Create a client using API key or JWT
    with success_fail_catch(":key: Authenticating client"):
        client = Client.authenticate()
    # Define query: a single file and a directory of inputs use different builders.
    with success_fail_catch(":package: Generating query"):
        query_builder = (
            BoltzQuery.from_file if source.is_file() else BoltzQuery.from_directory
        )
        query: BoltzQuery = query_builder(source, **parameters)
        query.save_parameters(output_dir)
    console.print("[blue]Generated query: [/blue]", end="")
    console.print(JSON.from_data(query.payload), style="blue")
    # Send a request
    with success_fail_catch(":brain: Processing folding job"):
        response = client.send_request(query, project_code)
    # Access confidence data
    console.print("[blue]Confidence data:[/blue]", end=" ")
    console.print(JSON.from_data(response.confidence_data), style="blue")
    with success_fail_catch(
        f":floppy_disk: Downloading results to `[green]{output_dir}[/green]`"
    ):
        response.download_results(output_dir=output_dir, force=force, unzip=unzip)
|