import io
import random
import warnings

import torch
import webdataset as wds

from pathlib import Path
from torch.utils.data import Dataset

from src.data.data_utils import TensorDict, collate_entity
from src.constants import WEBDATASET_SHARD_SIZE, WEBDATASET_VAL_SIZE


class ProcessedLigandPocketDataset(Dataset):
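    """Map-style dataset over a preprocessed ``.pt`` file with parallel
    'ligands' and 'pockets' tables.

    ``ligand_transform``/``pocket_transform`` are applied per item; with
    ``catch_errors=True`` a failing transform is logged and a random
    replacement item is returned instead.

    Minimal usage sketch (the path is a placeholder):

        >>> dataset = ProcessedLigandPocketDataset('train.pt')
        >>> loader = torch.utils.data.DataLoader(
        ...     dataset, batch_size=32,
        ...     collate_fn=ProcessedLigandPocketDataset.collate_fn)
    """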

    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
                 catch_errors=False):
        self.ligand_transform = ligand_transform
        self.pocket_transform = pocket_transform
        self.catch_errors = catch_errors
        self.pt_path = pt_path

        self.data = torch.load(pt_path)
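
        # Derive per-item counts that are not stored in the .pt file:
        # 'size' from the length of each item's 'x', 'n_bonds' from the
        # length of each item's 'bond_one_hot'.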
        for entity in ['ligands', 'pockets']:
            self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
            self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])

    def __len__(self):
        return len(self.data['ligands']['name'])

    def __getitem__(self, idx):
        data = {}
        data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
        data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
        try:
            if self.ligand_transform is not None:
                data['ligand'] = self.ligand_transform(data['ligand'])
            if self.pocket_transform is not None:
                data['pocket'] = self.pocket_transform(data['pocket'])
        except (RuntimeError, ValueError) as e:
            if self.catch_errors:
                warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
                              "Returning a random item instead.")
                rand_idx = random.randint(0, len(self) - 1)
                return self[rand_idx]
            else:
                raise
        return data

    @staticmethod
    def collate_fn(batch_pairs, ligand_transform=None):
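        """Collate a list of {'ligand': ..., 'pocket': ...} items into a batch.

        If ``ligand_transform`` is given, it is applied here with the
        batch-wide maximum ligand size, e.g. so that all ligands in the
        batch can be padded to a common size before collation.
        """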
        out = {}
        for entity in ['ligand', 'pocket']:
            batch = [x[entity] for x in batch_pairs]

            if entity == 'ligand' and ligand_transform is not None:
                max_size = max(x['size'].item() for x in batch)
                batch = [ligand_transform(x, max_size=max_size) for x in batch]

            out[entity] = TensorDict(**collate_entity(batch))

        return out


class ClusteredDataset(ProcessedLigandPocketDataset):
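    """Cluster-level view of the dataset: indices address clusters, and each
    access returns a randomly drawn member of that cluster, which keeps large
    clusters from dominating sampling.
    """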

    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
                 catch_errors=False):
        super().__init__(pt_path, ligand_transform, pocket_transform, catch_errors)
        self.clusters = list(self.data['clusters'].values())

    def __len__(self):
        return len(self.clusters)

    def __getitem__(self, cidx):
        cluster_inds = self.clusters[cidx]
        idx = random.choice(cluster_inds)
        return super().__getitem__(idx)


class DPODataset(ProcessedLigandPocketDataset):
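    """Preference-pair dataset for DPO-style training: each item holds a
    winning ligand ('ligand'), a losing ligand ('ligand_l') and the shared
    pocket. Files may store the winning entries under 'ligands_w'/'pockets_w'
    instead of the canonical keys.
    """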

    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
                 catch_errors=False):
        self.ligand_transform = ligand_transform
        self.pocket_transform = pocket_transform
        self.catch_errors = catch_errors
        self.pt_path = pt_path

        self.data = torch.load(pt_path)

        # Fall back to the '_w' (winning) keys if the canonical names are absent.
        if 'pockets' not in self.data:
            self.data['pockets'] = self.data['pockets_w']
        if 'ligands' not in self.data:
            self.data['ligands'] = self.data['ligands_w']

        # Chained `!=` would miss the case where only the last length differs,
        # so require all three lengths to be equal explicitly.
        if not (
            len(self.data['ligands']['name'])
            == len(self.data['ligands_l']['name'])
            == len(self.data['pockets']['name'])
        ):
            raise ValueError(
                "Error while importing DPO dataset: the numbers of winning "
                "ligands, losing ligands and pockets must match"
            )

        for entity in ['ligands', 'ligands_l', 'pockets']:
            self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
            self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])

    def __len__(self):
        return len(self.data['ligands']['name'])

    def __getitem__(self, idx):
        data = {}
        data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
        data['ligand_l'] = {key: val[idx] for key, val in self.data['ligands_l'].items()}
        data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
        try:
            if self.ligand_transform is not None:
                data['ligand'] = self.ligand_transform(data['ligand'])
                data['ligand_l'] = self.ligand_transform(data['ligand_l'])
            if self.pocket_transform is not None:
                data['pocket'] = self.pocket_transform(data['pocket'])
        except (RuntimeError, ValueError) as e:
            if self.catch_errors:
                warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
                              "Returning a random item instead.")
                rand_idx = random.randint(0, len(self) - 1)
                return self[rand_idx]
            else:
                raise
        return data

    @staticmethod
    def collate_fn(batch_pairs, ligand_transform=None):
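        """Collate DPO items; the winning and losing ligands are each
        transformed with their respective batch-wide maximum sizes.
        """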
        out = {}
        for entity in ['ligand', 'ligand_l', 'pocket']:
            batch = [x[entity] for x in batch_pairs]

            if entity in ['ligand', 'ligand_l'] and ligand_transform is not None:
                max_size = max(x['size'].item() for x in batch)
                batch = [ligand_transform(x, max_size=max_size) for x in batch]

            out[entity] = TensorDict(**collate_entity(batch))

        return out


class ProteinLigandWebDataset(wds.WebDataset):
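    """WebDataset-backed variant of the protein-ligand data; reuses the
    map-style dataset's collate function so DataLoader setup is identical.
    """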

    @staticmethod
    def collate_fn(batch_pairs, ligand_transform=None):
        return ProcessedLigandPocketDataset.collate_fn(batch_pairs, ligand_transform)


def wds_decoder(key, value):
    # Sample fields are assumed to be torch-serialized blobs; deserialize
    # each one from its raw bytes.
    return torch.load(io.BytesIO(value))


def preprocess_wds_item(data):
    out = {}
    for entity in ['ligand', 'pocket']:
        out[entity] = data['pt'][entity]
        # 'size' and 'n_bonds' should be plain per-item scalars; if they were
        # serialized as (empty) tensors, replace them with the scalar 0.
        for attr in ['size', 'n_bonds']:
            if torch.is_tensor(out[entity][attr]):
                assert len(out[entity][attr]) == 0
                out[entity][attr] = 0

    return out


def get_wds(data_path, stage, ligand_transform=None, pocket_transform=None):
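    """Build the WebDataset pipeline for one stage (e.g. 'train' or 'val').

    The nominal length is estimated from the shard index range for training
    and taken from WEBDATASET_VAL_SIZE otherwise.

    Minimal usage sketch (the path and batch size are placeholders):

        >>> dataset = get_wds('/path/to/webdataset', 'train')
        >>> loader = torch.utils.data.DataLoader(
        ...     dataset, batch_size=32,
        ...     collate_fn=ProteinLigandWebDataset.collate_fn)
    """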
    current_data_dir = Path(data_path, stage)
    shards = sorted(current_data_dir.glob('shard-?????.tar'),
                    key=lambda s: int(s.name.split('-')[-1].split('.')[0]))
    # Shard names are zero-padded, so the first/last entries of the numerically
    # sorted list carry the smallest/largest shard indices.
    min_shard = shards[0].name.split('-')[-1].split('.')[0]
    max_shard = shards[-1].name.split('-')[-1].split('.')[0]
    total_size = ((int(max_shard) - int(min_shard) + 1) * WEBDATASET_SHARD_SIZE
                  if stage == 'train' else WEBDATASET_VAL_SIZE)

    url = f'{data_path}/{stage}/shard-{{{min_shard}..{max_shard}}}.tar'

    # Identity wrappers by default; replaced below when transforms are given.
    def ligand_transform_wrapper(_data):
        return _data

    def pocket_transform_wrapper(_data):
        return _data

    if ligand_transform is not None:
        def ligand_transform_wrapper(_data):
            _data['pt']['ligand'] = ligand_transform(_data['pt']['ligand'])
            return _data

    if pocket_transform is not None:
        def pocket_transform_wrapper(_data):
            _data['pt']['pocket'] = pocket_transform(_data['pt']['pocket'])
            return _data

    return (
        ProteinLigandWebDataset(url, nodesplitter=wds.split_by_node)
        .decode(wds_decoder)
        .map(ligand_transform_wrapper)
        .map(pocket_transform_wrapper)
        .map(preprocess_wds_item)
        .with_length(total_size)
    )