# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
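# Batch ESMFold structure prediction: reads sequences from a FASTA file and
# writes one predicted PDB file per sequence to the output directory.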
import argparse
import logging
import sys
import typing as T
from pathlib import Path
from timeit import default_timer as timer

import torch

import esm
from esm.data import read_fasta
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%y/%m/%d %H:%M:%S",
)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
PathLike = T.Union[str, Path]
def enable_cpu_offloading(model):
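    """Wrap each ESM transformer layer, and then the model itself, in
    FullyShardedDataParallel with CPU offloading so parameters are kept on the
    CPU and streamed to the GPU only when needed. Initializes a single-process
    torch.distributed group, which FSDP requires."""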
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap
torch.distributed.init_process_group(
backend="nccl", init_method="tcp://localhost:9999", world_size=1, rank=0
)
wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))
with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
for layer_name, layer in model.layers.named_children():
wrapped_layer = wrap(layer)
setattr(model.layers, layer_name, wrapped_layer)
model = wrap(model)
return model
def init_model_on_gpu_with_cpu_offloading(model):
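    """Move the folding trunk to the GPU while keeping the ESM-2 language model
    behind FSDP CPU offloading (see enable_cpu_offloading)."""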
model = model.eval()
model_esm = enable_cpu_offloading(model.esm)
del model.esm
model.cuda()
model.esm = model_esm
return model
def create_batched_sequence_dataset(
sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024
) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]:
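    """Greedily group (header, sequence) pairs into batches whose total residue
    count does not exceed max_tokens_per_batch. A single sequence longer than
    the limit is yielded as its own batch."""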
batch_headers, batch_sequences, num_tokens = [], [], 0
for header, seq in sequences:
if (len(seq) + num_tokens > max_tokens_per_batch) and num_tokens > 0:
yield batch_headers, batch_sequences
batch_headers, batch_sequences, num_tokens = [], [], 0
batch_headers.append(header)
batch_sequences.append(seq)
num_tokens += len(seq)
yield batch_headers, batch_sequences
def create_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--fasta",
help="Path to input FASTA file",
type=Path,
required=True,
)
parser.add_argument(
"-o", "--pdb", help="Path to output PDB directory", type=Path, required=True
)
parser.add_argument(
"-m", "--model-dir", help="Parent path to Pretrained ESM data directory. ", type=Path, default=None
)
parser.add_argument(
"--num-recycles",
type=int,
default=None,
help="Number of recycles to run. Defaults to number used in training (4).",
)
parser.add_argument(
"--max-tokens-per-batch",
type=int,
default=1024,
help="Maximum number of tokens per gpu forward-pass. This will group shorter sequences together "
"for batched prediction. Lowering this can help with out of memory issues, if these occur on "
"short sequences.",
)
parser.add_argument(
"--chunk-size",
type=int,
default=None,
help="Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). "
"Equivalent to running a for loop over chunks of of each dimension. Lower values will "
"result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. "
"Default: None.",
)
parser.add_argument("--cpu-only", help="CPU only", action="store_true")
parser.add_argument("--cpu-offload", help="Enable CPU offloading", action="store_true")
return parser
def run(args):
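    """Read the input FASTA, sort sequences by length, batch them by token
    count, run ESMFold inference, and write one PDB file per sequence."""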
if not args.fasta.exists():
raise FileNotFoundError(args.fasta)
args.pdb.mkdir(exist_ok=True)
# Read fasta and sort sequences by length
logger.info(f"Reading sequences from {args.fasta}")
all_sequences = sorted(read_fasta(args.fasta), key=lambda header_seq: len(header_seq[1]))
logger.info(f"Loaded {len(all_sequences)} sequences from {args.fasta}")
logger.info("Loading model")
    # Use pre-downloaded ESM weights from --model-dir as the torch hub cache, if provided.
    if args.model_dir is not None:
        torch.hub.set_dir(args.model_dir)
model = esm.pretrained.esmfold_v1()
model = model.eval()
model.set_chunk_size(args.chunk_size)
if args.cpu_only:
model.esm.float() # convert to fp32 as ESM-2 in fp16 is not supported on CPU
model.cpu()
elif args.cpu_offload:
model = init_model_on_gpu_with_cpu_offloading(model)
else:
model.cuda()
logger.info("Starting Predictions")
    batched_sequences = create_batched_sequence_dataset(all_sequences, args.max_tokens_per_batch)
num_completed = 0
num_sequences = len(all_sequences)
for headers, sequences in batched_sequences:
start = timer()
try:
output = model.infer(sequences, num_recycles=args.num_recycles)
except RuntimeError as e:
if e.args[0].startswith("CUDA out of memory"):
if len(sequences) > 1:
logger.info(
f"Failed (CUDA out of memory) to predict batch of size {len(sequences)}. "
"Try lowering `--max-tokens-per-batch`."
)
else:
logger.info(
f"Failed (CUDA out of memory) on sequence {headers[0]} of length {len(sequences[0])}."
)
continue
raise
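        # Move all outputs to CPU before converting to PDB strings and reporting metrics.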
output = {key: value.cpu() for key, value in output.items()}
pdbs = model.output_to_pdb(output)
tottime = timer() - start
time_string = f"{tottime / len(headers):0.1f}s"
if len(sequences) > 1:
time_string = time_string + f" (amortized, batch size {len(sequences)})"
for header, seq, pdb_string, mean_plddt, ptm in zip(
headers, sequences, pdbs, output["mean_plddt"], output["ptm"]
):
output_file = args.pdb / f"{header}.pdb"
output_file.write_text(pdb_string)
num_completed += 1
logger.info(
f"Predicted structure for {header} with length {len(seq)}, pLDDT {mean_plddt:0.1f}, "
f"pTM {ptm:0.3f} in {time_string}. "
f"{num_completed} / {num_sequences} completed."
)
def main():
parser = create_parser()
args = parser.parse_args()
run(args)
if __name__ == "__main__":
main()
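# Example invocation (the script and file names here are placeholders):
#   python esmfold_inference.py -i sequences.fasta -o predicted_pdbs/ --chunk-size 128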