Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 -u | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import pathlib | |
| import sys | |
print("using", sys.executable)
# Put the local `esm` checkout first and the user site-packages second at the
# very front of the import search path, so the imports that follow resolve
# against them. A single slice assignment inserts both, preserving this order.
# NOTE(review): these absolute paths (including the python3.8 component) look
# deployment-specific — confirm they match the target environment.
sys.path[0:0] = [
    "/home/user/app/esm/",
    "/home/user/.local/lib/python3.8/site-packages",
]
| import os | |
| import torch | |
| from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer | |
def create_parser():
    """Build the command-line parser for the representation-extraction script.

    Positional arguments: ``model_location`` (str), ``fasta_file`` and
    ``output_dir`` (both parsed to ``pathlib.Path``). Options:
    ``--toks_per_batch`` (int, default 4096), ``--repr_layers`` (one or more
    ints, default ``[-1]``), ``--include`` (required; one or more of
    ``mean``, ``per_tok``, ``bos``, ``contacts``), ``--truncation_seq_length``
    (int, default 1022) and the ``--nogpu`` flag.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Extract per-token representations and model outputs for sequences in a FASTA file"  # noqa
    )
    parser.add_argument(
        "model_location",
        type=str,
        help="PyTorch model file OR name of pretrained model to download (see README for models)",
    )
    parser.add_argument(
        "fasta_file",
        type=pathlib.Path,
        help="FASTA file on which to extract representations",
    )
    parser.add_argument(
        "output_dir",
        type=pathlib.Path,
        help="output directory for extracted representations",
    )
    # This value bounds the number of tokens per batch (it is passed to
    # get_batch_indices in main), not the number of sequences.
    parser.add_argument(
        "--toks_per_batch",
        type=int,
        default=4096,
        help="maximum number of tokens per batch",
    )
    # main() accepts indices in [-(num_layers + 1), num_layers] and maps
    # negative ones back onto 0..num_layers, so -1 selects the final layer.
    parser.add_argument(
        "--repr_layers",
        type=int,
        default=[-1],
        nargs="+",
        help=(
            "layer indices from which to extract representations "
            "(0 to num_layers, inclusive; negative indices count back from "
            "the end, so -1 is the final layer)"
        ),
    )
    parser.add_argument(
        "--include",
        type=str,
        nargs="+",
        choices=["mean", "per_tok", "bos", "contacts"],
        help="specify which representations to return",
        required=True,
    )
    parser.add_argument(
        "--truncation_seq_length",
        type=int,
        default=1022,
        help="truncate sequences longer than the given value",
    )
    parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available")
    return parser
def main(args):
    """Extract model representations for every sequence in ``args.fasta_file``.

    Loads the model given by ``args.model_location``, runs it over the FASTA
    file in batches bounded by ``args.toks_per_batch`` and saves one
    ``<label>.pt`` file per sequence under ``args.output_dir``. Each saved dict
    holds ``"label"`` plus, depending on ``args.include``:

    * ``"representations"`` (``per_tok``): per-residue embeddings, keyed by layer;
    * ``"mean_representations"`` (``mean``): those embeddings averaged over residues;
    * ``"bos_representations"`` (``bos``): the embedding at token position 0;
    * ``"contacts"`` (``contacts``): the model's contact output, cropped to the sequence.

    Args:
        args: namespace produced by ``create_parser().parse_args()``.

    Raises:
        ValueError: if the loaded model is an ``MSATransformer``, or if any
            ``--repr_layers`` index lies outside ``[-(num_layers + 1), num_layers]``.
    """
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )

    # Decide once whether to run on the GPU; reused for every batch below.
    use_gpu = torch.cuda.is_available() and not args.nogpu
    if use_gpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(args.fasta_file)
    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
        batch_sampler=batches,
    )
    print(f"Read {args.fasta_file} with {len(dataset)} sequences")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    return_contacts = "contacts" in args.include

    # Validate explicitly instead of with `assert`, which is stripped under -O.
    num_layers = model.num_layers
    bad_layers = [i for i in args.repr_layers if not (-(num_layers + 1) <= i <= num_layers)]
    if bad_layers:
        raise ValueError(
            f"--repr_layers values {bad_layers} are out of range for a model with "
            f"{num_layers} layers; expected integers in [{-(num_layers + 1)}, {num_layers}]"
        )
    # Map negative indices onto 0..num_layers (e.g. -1 -> num_layers).
    repr_layers = [(i + num_layers + 1) % (num_layers + 1) for i in args.repr_layers]

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if use_gpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)
            # Logits are never saved, so they are not copied back to the CPU.
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")

            for i, label in enumerate(labels):
                output_file = args.output_dir / f"{label}.pt"
                # Create parent dirs too, in case the label contains path separators.
                # NOTE(review): labels come straight from the FASTA headers and are
                # used as paths unchecked — confirm inputs are trusted.
                output_file.parent.mkdir(parents=True, exist_ok=True)
                result = {"label": label}
                # The batch converter was built with args.truncation_seq_length, so
                # long sequences are truncated in `toks`. Slicing by the raw
                # len(strs[i]) would run past the truncated residues (presumably into
                # end/padding token positions) and corrupt per-token and mean outputs.
                truncate_len = min(args.truncation_seq_length, len(strs[i]))
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                if "per_tok" in args.include:
                    # Position 0 is skipped, so residues occupy 1..truncate_len.
                    result["representations"] = {
                        layer: t[i, 1 : truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in args.include:
                    result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in args.include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[i, :truncate_len, :truncate_len].clone()

                torch.save(
                    result,
                    output_file,
                )
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the extraction.
    cli_args = create_parser().parse_args()
    main(cli_args)