"""Protein dataset class."""
import os
import pickle
from pathlib import Path
from glob import glob
from typing import Optional, Sequence, List, Union
from functools import lru_cache
import tree
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from src.common import residue_constants, data_transforms, rigid_utils, protein
CA_IDX = residue_constants.atom_order['CA']
DTYPE_MAPPING = {
'aatype': torch.long,
'atom_positions': torch.double,
'atom_mask': torch.double,
}
class ProteinFeatureTransform:
def __init__(self,
                 unit: str = 'angstrom',
truncate_length: Optional[int] = None,
strip_missing_residues: bool = True,
recenter_and_scale: bool = True,
eps: float = 1e-8,
):
if unit == 'angstrom':
self.coordinate_scale = 1.0
elif unit in ('nm', 'nanometer'):
            self.coordinate_scale = 0.1
else:
raise ValueError(f"Invalid unit: {unit}")
if truncate_length is not None:
assert truncate_length > 0, f"Invalid truncate_length: {truncate_length}"
self.truncate_length = truncate_length
self.strip_missing_residues = strip_missing_residues
self.recenter_and_scale = recenter_and_scale
self.eps = eps
def __call__(self, chain_feats):
chain_feats = self.patch_feats(chain_feats)
if self.strip_missing_residues:
chain_feats = self.strip_ends(chain_feats)
if self.truncate_length is not None:
chain_feats = self.random_truncate(chain_feats, max_len=self.truncate_length)
# Recenter and scale atom positions
if self.recenter_and_scale:
chain_feats = self.recenter_and_scale_coords(chain_feats, coordinate_scale=self.coordinate_scale, eps=self.eps)
# Map to torch Tensor
chain_feats = self.map_to_tensors(chain_feats)
# Add extra features from AF2
chain_feats = self.protein_data_transform(chain_feats)
        # See pdb_data_loader.py (around line 170) for the reference implementation
return chain_feats
@staticmethod
def patch_feats(chain_feats):
        # Use CA-atom presence as the per-residue mask
        seq_mask = chain_feats['atom_mask'][:, CA_IDX]
        # Re-index residues to start at 0; chain breaks are preserved
        residue_idx = chain_feats['residue_index'] - np.min(chain_feats['residue_index'])
patch_feats = {
'seq_mask': seq_mask,
'residue_mask': seq_mask,
'residue_idx': residue_idx,
'fixed_mask': np.zeros_like(seq_mask),
'sc_ca_t': np.zeros(seq_mask.shape + (3, )),
}
chain_feats.update(patch_feats)
return chain_feats
@staticmethod
def strip_ends(chain_feats):
        # Strip unmodeled residues from both ends (aatype 20 == unknown/'X')
        modeled_idx = np.where(chain_feats['aatype'] != 20)[0]
        assert modeled_idx.size > 0, "No modeled residues found"
        min_idx, max_idx = np.min(modeled_idx), np.max(modeled_idx)
chain_feats = tree.map_structure(
lambda x: x[min_idx : (max_idx+1)], chain_feats)
return chain_feats
@staticmethod
def random_truncate(chain_feats, max_len):
L = chain_feats['aatype'].shape[0]
if L > max_len:
# Randomly truncate
start = np.random.randint(0, L - max_len + 1)
end = start + max_len
chain_feats = tree.map_structure(
lambda x: x[start : end], chain_feats)
return chain_feats
@staticmethod
def map_to_tensors(chain_feats):
chain_feats = {k: torch.as_tensor(v) for k,v in chain_feats.items()}
# Alter dtype
for k, dtype in DTYPE_MAPPING.items():
if k in chain_feats:
                chain_feats[k] = chain_feats[k].to(dtype)
return chain_feats
@staticmethod
def recenter_and_scale_coords(chain_feats, coordinate_scale, eps=1e-8):
        # Recenter on the masked CA centroid; positions of unmodeled atoms are
        # assumed to be zero-filled (atom37 convention), so a plain sum is safe.
        bb_pos = chain_feats['atom_positions'][:, CA_IDX]
        bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['seq_mask']) + eps)
centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
scaled_pos = centered_pos * coordinate_scale
chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
return chain_feats
@staticmethod
def protein_data_transform(chain_feats):
chain_feats.update(
{
"all_atom_positions": chain_feats["atom_positions"],
"all_atom_mask": chain_feats["atom_mask"],
}
)
chain_feats = data_transforms.atom37_to_frames(chain_feats)
chain_feats = data_transforms.atom37_to_torsion_angles("")(chain_feats)
chain_feats = data_transforms.get_backbone_frames(chain_feats)
chain_feats = data_transforms.get_chi_angles(chain_feats)
chain_feats = data_transforms.make_pseudo_beta("")(chain_feats)
chain_feats = data_transforms.make_atom14_masks(chain_feats)
chain_feats = data_transforms.make_atom14_positions(chain_feats)
        # Drop the temporary aliases added above
        chain_feats.pop("all_atom_positions")
        chain_feats.pop("all_atom_mask")
return chain_feats
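
# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of applying ProteinFeatureTransform to a pickled
# chain-features dict. The pickle path below is a hypothetical placeholder.
def _demo_feature_transform(pkl_path: str = 'data/processed/1abc.pkl'):
    transform = ProteinFeatureTransform(unit='nm', truncate_length=256)
    with open(os.path.expanduser(pkl_path), 'rb') as f:
        chain_feats = pickle.load(f)
    feats = transform(chain_feats)
    # Coordinates are now centered, scaled to nanometers, and mapped to
    # torch tensors with the AF2-style derived features attached.
    print({k: tuple(v.shape) for k, v in feats.items() if torch.is_tensor(v)})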
class MetadataFilter:
def __init__(self,
min_len: Optional[int] = None,
max_len: Optional[int] = None,
min_chains: Optional[int] = None,
max_chains: Optional[int] = None,
                 min_resolution: Optional[float] = None,
                 max_resolution: Optional[float] = None,
include_structure_method: Optional[List[str]] = None,
include_oligomeric_detail: Optional[List[str]] = None,
**kwargs,
):
self.min_len = min_len
self.max_len = max_len
self.min_chains = min_chains
self.max_chains = max_chains
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.include_structure_method = include_structure_method
self.include_oligomeric_detail = include_oligomeric_detail
def __call__(self, df):
_pre_filter_len = len(df)
if self.min_len is not None:
df = df[df['raw_seq_len'] >= self.min_len]
if self.max_len is not None:
df = df[df['raw_seq_len'] <= self.max_len]
if self.min_chains is not None:
df = df[df['num_chains'] >= self.min_chains]
if self.max_chains is not None:
df = df[df['num_chains'] <= self.max_chains]
if self.min_resolution is not None:
df = df[df['resolution'] >= self.min_resolution]
if self.max_resolution is not None:
df = df[df['resolution'] <= self.max_resolution]
        if self.include_structure_method is not None:
            # metadata CSV is assumed to store these under 'structure_method'
            df = df[df['structure_method'].isin(self.include_structure_method)]
        if self.include_oligomeric_detail is not None:
            df = df[df['oligomeric_detail'].isin(self.include_oligomeric_detail)]
        print(f">>> Metadata filter kept {len(df)} of {_pre_filter_len} samples")
return df
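
# --- Usage sketch (illustrative only) ---
# A small self-contained example of MetadataFilter on a toy metadata table;
# column names follow the dataset CSV assumed by this module.
def _demo_metadata_filter():
    df = pd.DataFrame({
        'raw_seq_len': [60, 512, 1400],
        'num_chains': [1, 2, 4],
        'resolution': [1.5, 2.8, 3.9],
    })
    keep = MetadataFilter(min_len=100, max_len=1000, max_resolution=3.0)
    print(keep(df))  # keeps only the 512-residue, 2.8 Å entry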
class RandomAccessProteinDataset(torch.utils.data.Dataset):
"""Random access to pickle protein objects of dataset.
dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors'])
Note that each value is a ndarray in shape (L, *), for example:
'atom_positions': (L, 37, 3)
"""
def __init__(self,
path_to_dataset: Union[Path, str],
path_to_seq_embedding: Optional[Path] = None,
metadata_filter: Optional[MetadataFilter] = None,
training: bool = True,
transform: Optional[ProteinFeatureTransform] = None,
suffix: Optional[str] = '.pkl',
                 accession_code_filter: Optional[Sequence[str]] = None,
**kwargs,
):
super().__init__()
path_to_dataset = os.path.expanduser(path_to_dataset)
suffix = suffix if suffix.startswith('.') else '.' + suffix
assert suffix in ('.pkl', '.pdb'), f"Invalid suffix: {suffix}"
if os.path.isfile(path_to_dataset): # path to csv file
assert path_to_dataset.endswith('.csv'), f"Invalid file extension: {path_to_dataset} (have to be .csv)"
self._df = pd.read_csv(path_to_dataset)
            self._df = self._df.sort_values('modeled_seq_len', ascending=False)
if metadata_filter:
self._df = metadata_filter(self._df)
self._data = self._df['processed_complex_path'].tolist()
elif os.path.isdir(path_to_dataset): # path to directory
self._data = sorted(glob(os.path.join(path_to_dataset, '*' + suffix)))
assert len(self._data) > 0, f"No {suffix} file found in '{path_to_dataset}'"
else: # path as glob pattern
_pattern = path_to_dataset
self._data = sorted(glob(_pattern))
assert len(self._data) > 0, f"No files found in '{_pattern}'"
        if accession_code_filter:
            wanted = set(accession_code_filter)
            self._data = [p for p in self._data
                          if os.path.splitext(os.path.basename(p))[0] in wanted]
self.data = np.asarray(self._data)
self.path_to_seq_embedding = os.path.expanduser(path_to_seq_embedding) \
if path_to_seq_embedding is not None else None
self.suffix = suffix
self.transform = transform
self.training = training # not implemented yet
@property
def num_samples(self):
return len(self.data)
def len(self):
return self.__len__()
def __len__(self):
return self.num_samples
def get(self, idx):
return self.__getitem__(idx)
    @lru_cache(maxsize=100)  # per-process cache; each DataLoader worker keeps its own copy
def __getitem__(self, idx):
"""return single pyg.Data() instance
"""
data_path = self.data[idx]
accession_code = os.path.splitext(os.path.basename(data_path))[0]
if self.suffix == '.pkl':
# Load pickled protein
with open(data_path, 'rb') as f:
data_object = pickle.load(f)
        elif self.suffix == '.pdb':
            # Parse the PDB file into the same dict-of-arrays format
            with open(data_path, 'r') as f:
                pdb_string = f.read()
            data_object = protein.from_pdb_string(pdb_string).to_dict()
        else:
            raise ValueError(f"Unsupported suffix: {self.suffix}")
# Apply data transform
if self.transform is not None:
data_object = self.transform(data_object)
        # Attach the precomputed sequence embedding, if available
        if self.path_to_seq_embedding is not None:
            embed_dict = torch.load(
                os.path.join(self.path_to_seq_embedding, f"{accession_code}.pt"),
                map_location='cpu',
            )
            # Layer 33 is the final representation layer of ESM-2 650M
            data_object['seq_emb'] = embed_dict['representations'][33].float()
data_object['accession_code'] = accession_code
return data_object # dict of arrays
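
# --- Usage sketch (illustrative only) ---
# A minimal example of wrapping the dataset in a DataLoader. Chains have
# variable length and do not stack, so batch_size=1 is assumed here; a custom
# collate_fn would be needed for larger batches. The directory is hypothetical.
def _demo_dataset_loader(data_dir: str = 'data/pdb_pickles'):
    dataset = RandomAccessProteinDataset(
        path_to_dataset=data_dir,
        transform=ProteinFeatureTransform(truncate_length=384),
        suffix='.pkl',
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    for batch in loader:
        print(batch['accession_code'], batch['aatype'].shape)
        break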
class PretrainPDBDataset(RandomAccessProteinDataset):
def __init__(self,
path_to_dataset: str,
metadata_filter: MetadataFilter,
transform: ProteinFeatureTransform,
**kwargs,
):
super(PretrainPDBDataset, self).__init__(path_to_dataset=path_to_dataset,
metadata_filter=metadata_filter,
transform=transform,
**kwargs,
)
class SamplingPDBDataset(RandomAccessProteinDataset):
def __init__(self,
path_to_dataset: str,
training: bool = False,
suffix: str = '.pdb',
transform: Optional[ProteinFeatureTransform] = None,
                 accession_code_filter: Optional[Sequence[str]] = None,
):
assert os.path.isdir(path_to_dataset), f"Invalid path (expected to be directory): {path_to_dataset}"
super(SamplingPDBDataset, self).__init__(path_to_dataset=path_to_dataset,
training=training,
suffix=suffix,
transform=transform,
                                                 accession_code_filter=accession_code_filter,
metadata_filter=None,
)
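
if __name__ == '__main__':
    # Smoke test for the self-contained demo above; the file-based demos
    # (_demo_feature_transform, _demo_dataset_loader) need real data paths.
    _demo_metadata_filter()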