import itertools
import os
import pickle
import re
import shutil
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import torch

from esm.constants import proteinseq_toks

# A raw MSA is a sequence of (description, aligned sequence) pairs.
RawMSA = Sequence[Tuple[str, str]]


class FastaBatchedDataset(object):
    def __init__(self, sequence_labels, sequence_strs):
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    @classmethod
    def from_file(cls, fasta_file):
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq():
            nonlocal cur_seq_label, buf
            if cur_seq_label is None:
                return
            sequence_labels.append(cur_seq_label)
            sequence_strs.append("".join(buf))
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):
                    _flush_current_seq()
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:
                    buf.append(line.strip())

        _flush_current_seq()

        assert len(set(sequence_labels)) == len(
            sequence_labels
        ), "Found duplicate sequence labels"

        return cls(sequence_labels, sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        return self.sequence_labels[idx], self.sequence_strs[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches
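

# Example usage (a minimal sketch; "proteins.fasta" and the token budget are
# illustrative values, not part of this module):
#
#     dataset = FastaBatchedDataset.from_file("proteins.fasta")
#     batch_indices = dataset.get_batch_indices(toks_per_batch=4096, extra_toks_per_seq=1)
#
# Each entry of `batch_indices` is a list of dataset indices; sequences are
# sorted by length and packed so that (longest sequence in the batch) *
# (batch size) stays within `toks_per_batch`.

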
class Alphabet(object):
    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = False,
        use_msa: bool = False,
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.use_msa = use_msa

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # Pad the vocabulary with extra null tokens so its size is a multiple of 8,
        # then place the append tokens at the end.
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")
        self.all_special_tokens = ["<eos>", "<unk>", "<pad>", "<cls>", "<mask>"]
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        return self.tok_to_idx.copy()

    def get_batch_converter(self, truncation_seq_length: Optional[int] = None):
        if self.use_msa:
            return MSABatchConverter(self, truncation_seq_length)
        else:
            return BatchConverter(self, truncation_seq_length)

    @classmethod
    def from_architecture(cls, name: str) -> "Alphabet":
        if name in ("ESM-1", "protein_bert_base"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        elif name in ("ESM-1b", "roberta_large"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = True
            use_msa = False
        elif name in ("MSA Transformer", "msa_transformer"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = False
            use_msa = True
        elif "invariant_gvp" in name.lower():
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>", "<cath>", "<af2>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        else:
            raise ValueError("Unknown architecture selected")
        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)
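
    # Example (a minimal sketch): build the ESM-1b alphabet and inspect it.
    # Exact index values depend on the token list in `proteinseq_toks`.
    #
    #     alphabet = Alphabet.from_architecture("ESM-1b")
    #     alphabet.padding_idx, alphabet.cls_idx, alphabet.eos_idx
    #     len(alphabet)  # size of the full vocabulary, including special tokens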
    def _tokenize(self, text) -> List[str]:
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string into a sequence of tokens, using the tokenizer.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                # Strip whitespace adjacent to the special token, so that
                # "x <mask> y" tokenizes the same way as "x<mask>y".
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self.unique_no_split_tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in self.unique_no_split_tokens
                        else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.unique_no_split_tokens
        tokenized_text = split_on_tokens(no_split_token, text)
        return tokenized_text

    def encode(self, text):
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
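

# Tokenization example (a minimal sketch; the exact integer indices depend on
# the alphabet built from `proteinseq_toks`):
#
#     alphabet = Alphabet.from_architecture("ESM-1b")
#     alphabet.tokenize("MKT<mask>LV")   # -> ["M", "K", "T", "<mask>", "L", "V"]
#     alphabet.encode("MKTLV")           # -> list of integer token indices
#
# Every standard residue token is in `unique_no_split_tokens`, so a plain
# protein string is effectively tokenized character by character.

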
class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, alphabet, truncation_seq_length: Optional[int] = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            seq_encoded_list = [
                seq_encoded[: self.truncation_seq_length] for seq_encoded in seq_encoded_list
            ]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.empty(
            (
                batch_size,
                max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            strs.append(seq_str)
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[
                i,
                int(self.alphabet.prepend_bos) : len(seq_encoded)
                + int(self.alphabet.prepend_bos),
            ] = seq
            if self.alphabet.append_eos:
                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx

        return labels, strs, tokens
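

# Example usage (a minimal sketch; the labels and sequences below are made up):
#
#     alphabet = Alphabet.from_architecture("ESM-1b")
#     batch_converter = alphabet.get_batch_converter()
#     data = [("protein1", "MKTVRQERLK"), ("protein2", "KALTARQ")]
#     labels, strs, tokens = batch_converter(data)
#     # tokens: LongTensor of shape (batch_size, max_len + BOS/EOS tokens),
#     # padded with alphabet.padding_idx.

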
class MSABatchConverter(BatchConverter):
    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        if isinstance(inputs[0][0], str):
            # Input is a single MSA; wrap it in a batch of size one.
            raw_batch: Sequence[RawMSA] = [inputs]
        else:
            raw_batch = inputs

        batch_size = len(raw_batch)
        max_alignments = max(len(msa) for msa in raw_batch)
        max_seqlen = max(len(msa[0][1]) for msa in raw_batch)

        tokens = torch.empty(
            (
                batch_size,
                max_alignments,
                max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, msa in enumerate(raw_batch):
            msa_seqlens = set(len(seq) for _, seq in msa)
            if not len(msa_seqlens) == 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
            labels.append(msa_labels)
            strs.append(msa_strs)
            tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens

        return labels, strs, tokens
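

# Example usage (a minimal sketch; the toy alignment below is made up):
#
#     alphabet = Alphabet.from_architecture("msa_transformer")
#     msa_batch_converter = alphabet.get_batch_converter()
#     msa = [("seq1", "MKTV-Q"), ("seq2", "MKTVAQ")]  # one aligned MSA
#     labels, strs, tokens = msa_batch_converter(msa)
#     # tokens: LongTensor of shape (num_MSAs, num_alignments, aligned_len + BOS)

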
def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    with open(path, "r") as f:
        for result in read_alignment_lines(
            f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
        ):
            yield result


def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    seq = desc = None

    def parse(s):
        # Optionally strip gap characters ("-") and insertions (lowercase letters)
        # from an aligned sequence.
        if not keep_gaps:
            s = re.sub("-", "", s)
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        return s.upper() if to_upper else s

    for line in lines:
        if len(line) > 0 and line[0] == ">":
            if seq is not None:
                yield desc, parse(seq)
            desc = line.strip().lstrip(">")
            seq = ""
        else:
            assert isinstance(seq, str)
            seq += line.strip()
    assert isinstance(seq, str) and isinstance(desc, str)
    yield desc, parse(seq)
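

# Example usage (a minimal sketch; "alignment.a3m" is a hypothetical file):
#
#     # Read an A3M alignment, dropping insertions (lowercase columns) so that
#     # every returned sequence has the same aligned length.
#     msa = [(desc, seq) for desc, seq in read_fasta("alignment.a3m", keep_insertions=False)]

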
class ESMStructuralSplitDataset(torch.utils.data.Dataset):
    """
    Structural Split Dataset as described in section A.10 of the supplement of our paper.
    https://doi.org/10.1101/622803

    We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
    generated on January 23, 2020.

    For each SCOPe domain:
        - We extract the sequence from the corresponding PDB file
        - We extract the 3D coordinates of the Carbon beta atoms, aligning them
          to the sequence. We put NaN where Cb atoms are missing.
        - From the 3D coordinates, we calculate a pairwise distance map, based
          on L2 distance
        - We use DSSP to generate secondary structure labels for the corresponding
          PDB file. This is also aligned to the sequence. We put '-' where SSP
          labels are missing.

    For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
    we have split the data into 5 partitions for cross validation. These are provided
    in a downloaded splits folder, in the format:
        splits/{split_level}/{cv_partition}/{train|valid}.txt
    where train is the partition and valid is the concatenation of the remaining 4.

    For each SCOPe domain, we provide a pkl dump that contains:
        - seq    : The domain sequence, stored as an L-length string
        - ssp    : The secondary structure labels, stored as an L-length string
        - dist   : The distance map, stored as an LxL numpy array
        - coords : The 3D coordinates, stored as an Lx3 numpy array
    """

    base_folder = "structural-data"
    # Each entry is (url, tar_filename, extracted_folder_name, md5_hash).
    file_list = [
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
            "splits.tar.gz",
            "splits",
            "456fe1c7f22c9d3d8dfe9735da52411d",
        ),
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
            "pkl.tar.gz",
            "pkl",
            "644ea91e56066c750cd50101d390f5db",
        ),
    ]

    def __init__(
        self,
        split_level,
        cv_partition,
        split,
        root_path=os.path.expanduser("~/.cache/torch/data/esm"),
        download=False,
    ):
        super().__init__()
        assert split in [
            "train",
            "valid",
        ], "split must be 'train' or 'valid'"
        self.root_path = root_path
        self.base_path = os.path.join(self.root_path, self.base_folder)

        # Download and unpack the data if requested.
        if download:
            self.download()

        self.split_file = os.path.join(
            self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
        )
        self.pkl_dir = os.path.join(self.base_path, "pkl")
        self.names = []
        with open(self.split_file) as f:
            self.names = f.read().splitlines()

    def __len__(self):
        return len(self.names)

    def _check_exists(self) -> bool:
        for (_, _, filename, _) in self.file_list:
            fpath = os.path.join(self.base_path, filename)
            if not os.path.exists(fpath) or not os.path.isdir(fpath):
                return False
        return True

    def download(self):
        if self._check_exists():
            print("Files already downloaded and verified")
            return

        from torchvision.datasets.utils import download_url

        for url, tar_filename, filename, md5_hash in self.file_list:
            download_path = os.path.join(self.base_path, tar_filename)
            download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
            shutil.unpack_archive(download_path, self.base_path)

    def __getitem__(self, idx):
        """
        Returns a dict with the following entries
            - seq    : Str (domain sequence)
            - ssp    : Str (SSP labels)
            - dist   : np.array (distance map)
            - coords : np.array (3D coordinates)
        """
        name = self.names[idx]
        pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
        with open(pkl_fname, "rb") as f:
            obj = pickle.load(f)
        return obj
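

# Example usage (a minimal sketch; the split level and partition values are
# illustrative, and `download=True` fetches and unpacks the archives on first use):
#
#     dataset = ESMStructuralSplitDataset(
#         split_level="superfamily",
#         cv_partition="4",
#         split="valid",
#         download=True,
#     )
#     sample = dataset[0]  # dict with "seq", "ssp", "dist", and "coords"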