# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import subprocess
import threading
from pathlib import Path

import numpy as np
import torch


def fasta_file_path(prefix_path):
    return prefix_path + ".fasta"


class FastaDataset(torch.utils.data.Dataset):
    """
    For loading protein sequence datasets in the common FASTA data format.
    """

    def __init__(self, path: str, cache_indices=False):
        self.fn = fasta_file_path(path)
        self.threadlocal = threading.local()
        self.cache = Path(f"{path}.fasta.idx.npy")
        if cache_indices:
            if self.cache.exists():
                self.offsets, self.sizes = np.load(self.cache)
            else:
                self.offsets, self.sizes = self._build_index(path)
                np.save(self.cache, np.stack([self.offsets, self.sizes]))
        else:
            self.offsets, self.sizes = self._build_index(path)

    def _get_file(self):
        # Each thread keeps its own file handle, so concurrent DataLoader
        # workers can seek independently without racing on one descriptor.
        if not hasattr(self.threadlocal, "f"):
            self.threadlocal.f = open(self.fn, "r")
        return self.threadlocal.f

    def __getitem__(self, idx):
        f = self._get_file()
        # Jump straight to this record's ">" header via the cached byte offset.
        f.seek(self.offsets[idx])
        desc = f.readline().strip()
        # Concatenate sequence lines until the next header or EOF.
        line = f.readline()
        seq = ""
        while line != "" and line[0] != ">":
            seq += line.strip()
            line = f.readline()
        return desc, seq

    def __len__(self):
        return self.offsets.size

    def _build_index(self, path: str):
        # Use grep and awk to scan at ~100 MB/s on a local SSD, so an
        # enormous 100 GB fasta indexes in roughly 10 min on a single core.
        path = fasta_file_path(path)
        # Byte offset of every '>' header line.
        bytes_offsets = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            " | grep --byte-offset '^>' -o | cut -d: -f1",
            shell=True,
        )
        # Length of each sequence (each header line starts a new record).
        fasta_lengths = subprocess.check_output(
            f"cat {path} | tqdm --bytes --total $(wc -c < {path})"
            " | awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'",
            shell=True,
        )
        # np.fromstring is deprecated; parse the newline-separated integers
        # from the subprocess output directly instead.
        bytes_np = np.array(bytes_offsets.split(), dtype=np.int64)
        sizes_np = np.array(fasta_lengths.split(), dtype=np.int64)
        return bytes_np, sizes_np

    def __setstate__(self, state):
        self.__dict__ = state
        self.threadlocal = threading.local()

    def __getstate__(self):
        # Thread-local file handles cannot be pickled; drop them here and
        # recreate them lazily after unpickling (see __setstate__).
        d = {}
        for i, v in self.__dict__.items():
            if i != "threadlocal":
                d[i] = v
        return d

    def __del__(self):
        if hasattr(self.threadlocal, "f"):
            self.threadlocal.f.close()
            del self.threadlocal.f

    @staticmethod
    def exists(path):
        return os.path.exists(fasta_file_path(path))


class EncodedFastaDataset(FastaDataset):
    """
    The FastaDataset returns raw sequences; this subclass instead returns
    token indices encoded with a dictionary.
    """

    def __init__(self, path, dictionary):
        super().__init__(path, cache_indices=True)
        self.dictionary = dictionary

    def __getitem__(self, idx):
        desc, seq = super().__getitem__(idx)
        return self.dictionary.encode_line(seq, line_tokenizer=list).long()
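

# ---------------------------------------------------------------------------
# Usage sketch, not part of the original file: a minimal smoke test showing
# random access through the cached byte-offset index. The "prefix" argument
# and this __main__ block are illustrative additions; EncodedFastaDataset
# additionally needs a fairseq-style dictionary exposing encode_line, which
# is not constructed here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    prefix = sys.argv[1]  # path prefix, i.e. the FASTA path minus ".fasta"
    if FastaDataset.exists(prefix):
        ds = FastaDataset(prefix, cache_indices=True)
        print(f"indexed {len(ds)} sequences")
        desc, seq = ds[0]  # O(1) seek via the cached offsets, no full scan
        print(desc)
        print(seq[:60])
    else:
        print(f"no FASTA file at {fasta_file_path(prefix)}")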