#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Extract per-token representations and model outputs for sequences in a FASTA file.

For every sequence in the input FASTA file, one ``<label>.pt`` file is written
to the output directory containing the representations selected via
``--include``.
"""

import argparse
import pathlib
import sys
from typing import Any, Dict, List, Optional, Sequence

print("using", sys.executable)
# NOTE(review): these search paths are specific to one machine. They are kept
# (and kept *before* the torch/esm imports) because they decide which copies of
# those packages get imported -- consider replacing them with PYTHONPATH.
sys.path.insert(0, "/home/user/.local/lib/python3.8/site-packages")
sys.path.insert(0, "/home/user/app/esm/")
import os

import torch

from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer


def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the extraction script.

    Returns:
        A parser with three positionals (``model_location``, ``fasta_file``,
        ``output_dir``) and the options ``--toks_per_batch``,
        ``--repr_layers``, ``--include`` (required),
        ``--truncation_seq_length`` and ``--nogpu``.
    """
    parser = argparse.ArgumentParser(
        description="Extract per-token representations and model outputs for sequences in a FASTA file"  # noqa
    )

    parser.add_argument(
        "model_location",
        type=str,
        help="PyTorch model file OR name of pretrained model to download (see README for models)",
    )
    parser.add_argument(
        "fasta_file",
        type=pathlib.Path,
        help="FASTA file on which to extract representations",
    )
    parser.add_argument(
        "output_dir",
        type=pathlib.Path,
        help="output directory for extracted representations",
    )

    parser.add_argument("--toks_per_batch", type=int, default=4096, help="maximum batch size")
    parser.add_argument(
        "--repr_layers",
        type=int,
        default=[-1],
        nargs="+",
        help="layers indices from which to extract representations (0 to num_layers, inclusive)",
    )
    parser.add_argument(
        "--include",
        type=str,
        nargs="+",
        choices=["mean", "per_tok", "bos", "contacts"],
        help="specify which representations to return",
        required=True,
    )
    parser.add_argument(
        "--truncation_seq_length",
        type=int,
        default=1022,
        help="truncate sequences longer than the given value",
    )

    parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available")
    return parser


def _resolve_repr_layers(requested: Sequence[int], num_layers: int) -> List[int]:
    """Validate layer indices and map negative ones onto ``0..num_layers``.

    Raises:
        ValueError: if any index lies outside
            ``[-(num_layers + 1), num_layers]``.
    """
    lowest = -(num_layers + 1)
    if not all(lowest <= i <= num_layers for i in requested):
        # Raised explicitly (rather than asserted) so the check also runs
        # when Python is started with -O.
        raise ValueError(
            f"--repr_layers must be within [{lowest}, {num_layers}], got {list(requested)}"
        )
    return [(i + num_layers + 1) % (num_layers + 1) for i in requested]


def _build_result(
    label: str,
    index: int,
    seq_len: int,
    include: Sequence[str],
    representations: Dict[int, torch.Tensor],
    contacts: Optional[torch.Tensor],
) -> Dict[str, Any]:
    """Assemble the dictionary saved for one sequence of a batch.

    Args:
        label: Sequence label; stored under ``"label"``.
        index: Position of the sequence within the batch.
        seq_len: Number of positions to keep for this sequence.
        include: Requested outputs (``"per_tok"``, ``"mean"``, ``"bos"``).
        representations: CPU tensors keyed by layer; position 0 along the
            second axis is stored as the "bos" representation and positions
            ``1..seq_len`` are used for the per-token and mean outputs.
        contacts: Contact tensor for the batch, or ``None`` when contacts
            were not requested.
    """
    result: Dict[str, Any] = {"label": label}
    # Call clone on tensors to ensure tensors are not views into a larger representation
    # See https://github.com/pytorch/pytorch/issues/1995
    if "per_tok" in include:
        result["representations"] = {
            layer: t[index, 1 : seq_len + 1].clone() for layer, t in representations.items()
        }
    if "mean" in include:
        result["mean_representations"] = {
            layer: t[index, 1 : seq_len + 1].mean(0).clone()
            for layer, t in representations.items()
        }
    if "bos" in include:
        result["bos_representations"] = {
            layer: t[index, 0].clone() for layer, t in representations.items()
        }
    if contacts is not None:
        result["contacts"] = contacts[index, :seq_len, :seq_len].clone()
    return result


def main(args: argparse.Namespace) -> None:
    """Run the model over a FASTA file and save one ``.pt`` file per sequence.

    Args:
        args: Parsed arguments as produced by :func:`create_parser`.

    Raises:
        ValueError: if the loaded model is an ``MSATransformer`` or if a
            requested ``--repr_layers`` index is out of range.
    """
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )

    # Decide once whether to use the GPU; the answer cannot change mid-run.
    use_gpu = torch.cuda.is_available() and not args.nogpu
    if use_gpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(args.fasta_file)
    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
        batch_sampler=batches,
    )
    print(f"Read {args.fasta_file} with {len(dataset)} sequences")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    return_contacts = "contacts" in args.include
    repr_layers = _resolve_repr_layers(args.repr_layers, model.num_layers)
    max_len = args.truncation_seq_length

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if use_gpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)

            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            contacts = out["contacts"].to(device="cpu") if return_contacts else None

            for i, label in enumerate(labels):
                output_file = args.output_dir / f"{label}.pt"
                # A label may contain path separators, so create its parent too.
                output_file.parent.mkdir(parents=True, exist_ok=True)

                # The batch converter is given the truncation length, so the
                # token tensor may hold fewer positions than the raw string.
                # Presumably `strs` still carries the untruncated sequences
                # (true for upstream fair-esm -- confirm for the installed
                # version); clamping keeps the slices on residue positions.
                seq_len = len(strs[i])
                if max_len is not None and max_len > 0:
                    seq_len = min(seq_len, max_len)

                result = _build_result(
                    label, i, seq_len, args.include, representations, contacts
                )
                torch.save(result, output_file)


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)