import argparse
import logging
import sys
import typing as T
from pathlib import Path
from timeit import default_timer as timer

import torch

import esm
from esm.data import read_fasta

logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%y/%m/%d %H:%M:%S",
)

console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


PathLike = T.Union[str, Path]


def enable_cpu_offloading(model):
    """Wrap each of the model's layers with FSDP so that parameters are offloaded to CPU."""
    from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import enable_wrap, wrap

    # FSDP requires an initialized process group, even when running a single process.
    torch.distributed.init_process_group(
        backend="nccl", init_method="tcp://localhost:9999", world_size=1, rank=0
    )

    wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))

    with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
        for layer_name, layer in model.layers.named_children():
            wrapped_layer = wrap(layer)
            setattr(model.layers, layer_name, wrapped_layer)
        model = wrap(model)

    return model


def init_model_on_gpu_with_cpu_offloading(model):
    """Move the folding model to GPU while keeping the ESM-2 language model CPU-offloaded via FSDP."""
    model = model.eval()
    model_esm = enable_cpu_offloading(model.esm)
    del model.esm
    model.cuda()
    model.esm = model_esm
    return model


def create_batched_sequence_dataset(
    sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024
) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]:
    """Group (header, sequence) pairs into batches of at most max_tokens_per_batch total residues."""
    batch_headers, batch_sequences, num_tokens = [], [], 0
    for header, seq in sequences:
        if (len(seq) + num_tokens > max_tokens_per_batch) and num_tokens > 0:
            yield batch_headers, batch_sequences
            batch_headers, batch_sequences, num_tokens = [], [], 0
        batch_headers.append(header)
        batch_sequences.append(seq)
        num_tokens += len(seq)

    yield batch_headers, batch_sequences
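
# Illustration (hypothetical inputs), assuming max_tokens_per_batch=10:
# [("a", "MK"), ("b", "MKTAYIA"), ("c", "MKTAYIAKQR")] yields two batches,
# (["a", "b"], ["MK", "MKTAYIA"]) with 2 + 7 = 9 tokens and (["c"], ["MKTAYIAKQR"])
# with 10 tokens, since adding "c" to the first batch would exceed the limit.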


def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--fasta",
        help="Path to input FASTA file",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-o", "--pdb", help="Path to output PDB directory", type=Path, required=True
    )
    parser.add_argument(
        "-m",
        "--model-dir",
        help="Parent path to the pretrained ESM data directory.",
        type=Path,
        default=None,
    )
    parser.add_argument(
        "--num-recycles",
        type=int,
        default=None,
        help="Number of recycles to run. Defaults to number used in training (4).",
    )
    parser.add_argument(
        "--max-tokens-per-batch",
        type=int,
        default=1024,
        help="Maximum number of tokens per GPU forward-pass. This will group shorter sequences together "
        "for batched prediction. Lowering this can help with out-of-memory issues, if these occur on "
        "short sequences.",
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=None,
        help="Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). "
        "Equivalent to running a for loop over chunks of each dimension. Lower values will "
        "result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. "
        "Default: None.",
    )
    parser.add_argument("--cpu-only", help="CPU only", action="store_true")
    parser.add_argument("--cpu-offload", help="Enable CPU offloading", action="store_true")
    return parser
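
# Example (hypothetical arguments): create_parser().parse_args(
#     ["-i", "proteins.fasta", "-o", "out_pdbs", "--chunk-size", "64"]
# ) returns a Namespace with fasta=Path("proteins.fasta"), pdb=Path("out_pdbs"),
# chunk_size=64, and the remaining options left at their defaults.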


def run(args):
    if not args.fasta.exists():
        raise FileNotFoundError(args.fasta)

    args.pdb.mkdir(exist_ok=True)

    logger.info(f"Reading sequences from {args.fasta}")
    # Sort by length so that similarly sized sequences end up in the same batch.
    all_sequences = sorted(read_fasta(args.fasta), key=lambda header_seq: len(header_seq[1]))
    logger.info(f"Loaded {len(all_sequences)} sequences from {args.fasta}")

    logger.info("Loading model")

    if args.model_dir is not None:
        torch.hub.set_dir(args.model_dir)

    model = esm.pretrained.esmfold_v1()

    model = model.eval()
    model.set_chunk_size(args.chunk_size)

    if args.cpu_only:
        # Half precision is poorly supported on CPU, so run the language model in fp32.
        model.esm.float()
        model.cpu()
    elif args.cpu_offload:
        model = init_model_on_gpu_with_cpu_offloading(model)
    else:
        model.cuda()
    logger.info("Starting Predictions")
    batched_sequences = create_batched_sequence_dataset(all_sequences, args.max_tokens_per_batch)

    num_completed = 0
    num_sequences = len(all_sequences)
    for headers, sequences in batched_sequences:
        start = timer()
        try:
            output = model.infer(sequences, num_recycles=args.num_recycles)
        except RuntimeError as e:
            if e.args[0].startswith("CUDA out of memory"):
                if len(sequences) > 1:
                    logger.info(
                        f"Failed (CUDA out of memory) to predict batch of size {len(sequences)}. "
                        "Try lowering `--max-tokens-per-batch`."
                    )
                else:
                    logger.info(
                        f"Failed (CUDA out of memory) on sequence {headers[0]} of length {len(sequences[0])}."
                    )
                continue
            raise

        output = {key: value.cpu() for key, value in output.items()}
        pdbs = model.output_to_pdb(output)
        tottime = timer() - start
        time_string = f"{tottime / len(headers):0.1f}s"
        if len(sequences) > 1:
            time_string = time_string + f" (amortized, batch size {len(sequences)})"
        for header, seq, pdb_string, mean_plddt, ptm in zip(
            headers, sequences, pdbs, output["mean_plddt"], output["ptm"]
        ):
            output_file = args.pdb / f"{header}.pdb"
            output_file.write_text(pdb_string)
            num_completed += 1
            logger.info(
                f"Predicted structure for {header} with length {len(seq)}, pLDDT {mean_plddt:0.1f}, "
                f"pTM {ptm:0.3f} in {time_string}. "
                f"{num_completed} / {num_sequences} completed."
            )


def main():
    parser = create_parser()
    args = parser.parse_args()
    run(args)


if __name__ == "__main__":
    main()
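

# Example invocation (hypothetical paths; substitute this script's filename).
# ESMFold weights are fetched via torch.hub on first use unless -m/--model-dir
# points at an existing download directory:
#   python fold.py -i proteins.fasta -o predicted_pdbs/ --chunk-size 128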