# diffdock/confidence/dataset.py
# Uploaded by gcorso in commit 4a3f787 ("first commit"); file size 17.6 kB.
import itertools
import math
import os
import pickle
import random
from argparse import Namespace
from functools import partial
import copy
import numpy as np
import pandas as pd
import torch
import yaml
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from datasets.pdbbind import PDBBind
from utils.diffusion_utils import get_t_schedule
from utils.sampling import randomize_position, sampling
from utils.utils import get_model
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl
class ListDataset(Dataset):
    """Expose an already-built, in-memory sequence through the ``Dataset`` interface.

    The wrapped sequence is stored as ``data_list``; ``len`` and ``get`` simply
    delegate to it.
    """

    def __init__(self, list):
        # NOTE: the parameter name shadows the builtin ``list``; it is kept
        # unchanged so that keyword callers (``ListDataset(list=...)``) keep working.
        super().__init__()
        self.data_list = list

    def len(self) -> int:
        """Return the number of stored items."""
        items = self.data_list
        return len(items)

    def get(self, idx: int) -> Data:
        """Return the item stored at position ``idx``."""
        items = self.data_list
        return items[idx]
def get_cache_path(args, split):
    """Return the cache directory path that encodes the given dataset settings.

    The root is ``args.cache_path`` with ``_torsion`` appended unless
    ``args.no_torsion`` is set and ``_allatoms`` appended when ``args.all_atoms``
    is set. The final path component is assembled from the preprocessing
    parameters, so different settings map to different directories.

    Args:
        args: Namespace-like object providing the attributes read below.
        split: ``'train'`` selects ``args.split_train``; any other value selects
            ``args.split_val``.

    Returns:
        The joined cache path as a string.
    """
    root_dir = args.cache_path
    if not args.no_torsion:
        root_dir += '_torsion'
    if args.all_atoms:
        root_dir += '_allatoms'

    split_file = args.split_train if split == 'train' else args.split_val
    split_stem = os.path.splitext(os.path.basename(split_file))[0]

    # Mandatory components, in the order they appear in the directory name.
    name_parts = [
        f'limit{args.limit_complexes}',
        f'_INDEX{split_stem}',
        f'_maxLigSize{args.max_lig_size}',
        f'_H{int(not args.remove_hs)}',
        f'_recRad{args.receptor_radius}',
        f'_recMax{args.c_alpha_max_neighbors}',
    ]

    # Optional components; the related attributes are only read when needed.
    if args.all_atoms:
        name_parts.append(f'_atomRad{args.atom_radius}_atomMax{args.atom_max_neighbors}')
    if not (args.no_torsion or args.num_conformers == 1):
        name_parts.append(f'_confs{args.num_conformers}')
    if args.esm_embeddings_path is not None:
        name_parts.append('_esmEmbeddings')

    return os.path.join(root_dir, ''.join(name_parts))
def get_args_and_cache_path(original_model_dir, split):
    """Load a model's saved hyper-parameters and derive its graph cache path.

    Args:
        original_model_dir: Directory containing ``model_parameters.yml``.
        split: Split name forwarded to ``get_cache_path``.

    Returns:
        A tuple ``(model_args, cache_path)`` where ``model_args`` is a
        ``Namespace`` built from the YAML mapping and ``cache_path`` is the
        result of ``get_cache_path(model_args, split)``.
    """
    params_file = f'{original_model_dir}/model_parameters.yml'
    with open(params_file) as handle:
        hyperparams = yaml.full_load(handle)
    model_args = Namespace(**hyperparams)
    graph_cache_dir = get_cache_path(model_args, split)
    return model_args, graph_cache_dir
class ConfidenceDataset(Dataset):
    """Complex graphs paired with sampled ligand poses and their RMSDs.

    On construction the dataset
      1. loads the cached complex graphs (``heterographs.pkl``), building them
         through ``PDBBind`` first if the file is missing;
      2. makes sure sampled ligand positions and RMSDs exist on disk, running
         ``preprocessing`` (sampling with the original model) when they do not;
      3. loads either a single positions/RMSD cache or combines several caches
         identified by ``cache_ids_to_combine``.

    ``get`` returns a deep copy of a complex graph whose ligand position is
    replaced by one of the sampled poses, together with a label derived from
    ``rmsd_classification_cutoff``.
    """

    def __init__(self, cache_path, original_model_dir, split, device, limit_complexes,
                 inference_steps, samples_per_complex, all_atoms,
                 args, balance=False, use_original_model_cache=True, rmsd_classification_cutoff=2,
                 cache_ids_to_combine=None, cache_creation_id=None):
        """Load (and create if necessary) all caches backing the dataset.

        Args:
            cache_path: Root directory for the sampled positions/RMSD caches.
            original_model_dir: Directory holding ``model_parameters.yml`` and
                ``best_model.pt`` of the model used for sampling.
            split: ``'train'`` or ``'val'``.
            device: Device on which the original model is run during preprocessing.
            limit_complexes: If > 0, keep at most this many complexes.
            inference_steps: Number of steps passed to the sampling routine.
            samples_per_complex: Number of poses sampled per complex in one cache.
            all_atoms: Whether the graphs contain an ``'atom'`` node type.
            args: Dataset arguments; used for ``PDBBind`` and, when
                ``use_original_model_cache`` is false, for the graph cache path.
            balance: If true, ``get`` draws the label first and then a pose
                matching that label.
            use_original_model_cache: Use the graph cache derived from the
                original model's arguments instead of ``args``.
            rmsd_classification_cutoff: Scalar cutoff, or a list of cutoffs
                (lists are only supported with ``balance=False``).
            cache_ids_to_combine: Optional iterable of cache ids whose samples
                are concatenated per complex.
            cache_creation_id: Optional id; when set, preprocessing output is
                written to ``..._id{cache_creation_id}.pkl`` files.
        """
        super(ConfidenceDataset, self).__init__()

        self.device = device
        self.inference_steps = inference_steps
        self.limit_complexes = limit_complexes
        self.all_atoms = all_atoms
        self.original_model_dir = original_model_dir
        self.balance = balance
        self.use_original_model_cache = use_original_model_cache
        self.rmsd_classification_cutoff = rmsd_classification_cutoff
        self.cache_ids_to_combine = cache_ids_to_combine
        self.cache_creation_id = cache_creation_id
        self.samples_per_complex = samples_per_complex

        self.original_model_args, original_model_cache = get_args_and_cache_path(original_model_dir, split)
        self.complex_graphs_cache = original_model_cache if self.use_original_model_cache else get_cache_path(args, split)
        print('Using the cached complex graphs of the original model args' if self.use_original_model_cache else 'Not using the cached complex graphs of the original model args. Instead the complex graphs are used that are at the location given by the dataset parameters given to confidence_train.py')
        print(self.complex_graphs_cache)

        graphs_file = os.path.join(self.complex_graphs_cache, "heterographs.pkl")
        if not os.path.exists(graphs_file):
            print(f'HAPPENING | Complex graphs path does not exist yet: {graphs_file}. For that reason, we are now creating the dataset.')
            # Instantiated only for its side effect; the result is not kept.
            # Presumably this writes the graph cache to disk - verify against PDBBind.
            PDBBind(transform=None, root=args.data_dir, limit_complexes=args.limit_complexes,
                    receptor_radius=args.receptor_radius,
                    cache_path=args.cache_path, split_path=args.split_val if split == 'val' else args.split_train,
                    remove_hs=args.remove_hs, max_lig_size=None,
                    c_alpha_max_neighbors=args.c_alpha_max_neighbors,
                    matching=not args.no_torsion, keep_original=True,
                    popsize=args.matching_popsize,
                    maxiter=args.matching_maxiter,
                    all_atoms=args.all_atoms,
                    atom_radius=args.atom_radius,
                    atom_max_neighbors=args.atom_max_neighbors,
                    esm_embeddings_path=args.esm_embeddings_path,
                    require_ligand=True)

        print(f'HAPPENING | Loading complex graphs from: {graphs_file}')
        # NOTE(review): pickle.load executes arbitrary code; only load caches you created yourself.
        with open(graphs_file, 'rb') as f:
            complex_graphs = pickle.load(f)
        self.complex_graph_dict = {d.name: d for d in complex_graphs}

        self.full_cache_path = os.path.join(cache_path, f'model_{os.path.splitext(os.path.basename(original_model_dir))[0]}'
                                                        f'_split_{split}_limit_{limit_complexes}')

        # Run the (expensive) sampling only if the file this run would write is missing.
        creation_suffix = '' if self.cache_creation_id is None else f'_id{self.cache_creation_id}'
        if not os.path.exists(os.path.join(self.full_cache_path, f"ligand_positions{creation_suffix}.pkl")):
            os.makedirs(self.full_cache_path, exist_ok=True)
            self.preprocessing(original_model_cache)

        if self.cache_ids_to_combine is None:
            generated_rmsd_complex_names = self._load_single_cache(original_model_cache)
        else:
            generated_rmsd_complex_names = self._load_combined_caches()

        print('Number of complex graphs: ', len(self.complex_graph_dict))
        print('Number of RMSDs and positions for the complex graphs: ', len(self.full_ligand_positions))

        self.all_samples_per_complex = samples_per_complex * (1 if self.cache_ids_to_combine is None else len(self.cache_ids_to_combine))

        self.positions_rmsds_dict = {name: (pos, rmsd) for name, pos, rmsd in zip(generated_rmsd_complex_names, self.full_ligand_positions, self.rmsds)}
        # Keep the complexes that have both a graph and sampled poses. Iterating
        # the dict (insertion-ordered) instead of a set intersection makes the
        # order - and therefore the subset kept by ``limit_complexes`` -
        # reproducible across interpreter runs.
        self.dataset_names = [name for name in self.positions_rmsds_dict if name in self.complex_graph_dict]
        if limit_complexes > 0:
            self.dataset_names = self.dataset_names[:limit_complexes]

    def _load_single_cache(self, original_model_cache):
        """Load ``ligand_positions.pkl`` and return the complex names in matching order."""
        positions_file = os.path.join(self.full_cache_path, "ligand_positions.pkl")
        print(f'HAPPENING | Loading positions and rmsds from: {positions_file}')
        with open(positions_file, 'rb') as f:
            self.full_ligand_positions, self.rmsds = pickle.load(f)

        names_file = os.path.join(self.full_cache_path, "complex_names_in_same_order.pkl")
        if os.path.exists(names_file):
            with open(names_file, 'rb') as f:
                generated_rmsd_complex_names = pickle.load(f)
        else:
            print('HAPPENING | The path, ', names_file,
                  ' does not exist. \n => We assume that means that we are using a ligand_positions.pkl where the '
                  'code was not saving the complex names for them yet. We now instead use the complex names of '
                  'the dataset that the original model used to create the ligand positions and RMSDs.')
            with open(os.path.join(original_model_cache, "heterographs.pkl"), 'rb') as f:
                original_model_complex_graphs = pickle.load(f)
            generated_rmsd_complex_names = [d.name for d in original_model_complex_graphs]
        assert (len(self.rmsds) == len(generated_rmsd_complex_names))
        return generated_rmsd_complex_names

    def _load_combined_caches(self):
        """Load every cache in ``cache_ids_to_combine`` and concatenate samples per complex.

        Only complexes present in *all* caches are kept, so every complex ends up
        with the same number of samples. Returns the complex names in the order
        used for ``self.full_ligand_positions`` / ``self.rmsds``.

        Raises:
            Exception: If the positions file of one of the cache ids is missing.
        """
        names_per_cache, lookup_per_cache = [], []
        for cache_id in self.cache_ids_to_combine:
            positions_file = os.path.join(self.full_cache_path, f"ligand_positions_id{cache_id}.pkl")
            print(f'HAPPENING | Loading positions and rmsds from cache_id from the path: {positions_file}')
            if not os.path.exists(positions_file):
                # be careful with changing this error message since it is sometimes caught in a try/except
                raise Exception(f'The generated ligand positions with cache_id do not exist: {cache_id}')
            with open(positions_file, 'rb') as f:
                full_ligand_positions, rmsds = pickle.load(f)
            with open(os.path.join(self.full_cache_path, f"complex_names_in_same_order_id{cache_id}.pkl"), 'rb') as f:
                names = pickle.load(f)
            names_per_cache.append(names)
            lookup_per_cache.append({name: (pos, rmsd) for name, pos, rmsd in zip(names, full_ligand_positions, rmsds)})

        if names_per_cache:
            common_names = set(names_per_cache[0]).intersection(*names_per_cache[1:])
            # First-cache order (duplicates removed) keeps the result deterministic.
            names_order = [name for name in dict.fromkeys(names_per_cache[0]) if name in common_names]
            n_total = len(set().union(*names_per_cache))
            if n_total != len(names_order):
                print(f'HAPPENING | {n_total - len(names_order)} complexes are missing from at least one of the combined caches and are skipped.')
        else:
            names_order = []

        self.full_ligand_positions, self.rmsds = [], []
        for name in names_order:
            # Axis 0 indexes the samples of one complex (see ``preprocessing``).
            self.full_ligand_positions.append(np.concatenate([lookup[name][0] for lookup in lookup_per_cache], axis=0))
            self.rmsds.append(np.concatenate([lookup[name][1] for lookup in lookup_per_cache], axis=0))
        return names_order

    def len(self):
        """Return the number of complexes available in the dataset."""
        return len(self.dataset_names)

    @staticmethod
    def _zero_times(n):
        """Return a ``{'tr', 'rot', 'tor'}`` dict of independent zero tensors of length ``n``."""
        return {'tr': torch.zeros(n), 'rot': torch.zeros(n), 'tor': torch.zeros(n)}

    def get(self, idx):
        """Return a copy of complex ``idx`` with a sampled ligand pose and its label.

        With ``balance`` the label is drawn uniformly from {0, 1} and a pose with
        that label is sampled if one exists. Without ``balance`` a pose is drawn
        uniformly and labelled by comparing its RMSD against the cutoff(s).

        Raises:
            ValueError: If ``balance`` is set and the cutoff is a list.
        """
        name = self.dataset_names[idx]
        complex_graph = copy.deepcopy(self.complex_graph_dict[name])
        positions, rmsds = self.positions_rmsds_dict[name]

        if self.balance:
            if isinstance(self.rmsd_classification_cutoff, list):
                raise ValueError("a list for --rmsd_classification_cutoff can only be used without --balance")
            label = random.randint(0, 1)
            success = rmsds < self.rmsd_classification_cutoff
            n_success = np.count_nonzero(success)
            if label == 0 and n_success != self.all_samples_per_complex:
                # sample negative complex
                sample = random.randint(0, self.all_samples_per_complex - n_success - 1)
                lig_pos = positions[~success][sample]
                complex_graph['ligand'].pos = torch.from_numpy(lig_pos)
            else:
                # sample positive complex. This branch is also reached when a
                # negative was requested but none exists; the returned pose is
                # then positive, so the label must say so.
                label = 1
                if n_success > 0:  # with no successful sample the pose stored in the cached graph is kept
                    sample = random.randint(0, n_success - 1)
                    lig_pos = positions[success][sample]
                    complex_graph['ligand'].pos = torch.from_numpy(lig_pos)
            complex_graph.y = torch.tensor(label).float()
        else:
            sample = random.randint(0, self.all_samples_per_complex - 1)
            complex_graph['ligand'].pos = torch.from_numpy(positions[sample])
            complex_graph.y = torch.tensor(rmsds[sample] < self.rmsd_classification_cutoff).float().unsqueeze(0)
            if isinstance(self.rmsd_classification_cutoff, list):
                # One-hot style bin membership over [0, c0), [c0, c1), ..., [c_last, inf).
                complex_graph.y_binned = torch.tensor(np.logical_and(rmsds[sample] < self.rmsd_classification_cutoff + [math.inf], rmsds[sample] >= [0] + self.rmsd_classification_cutoff), dtype=torch.float).unsqueeze(0)
                complex_graph.y = torch.tensor(rmsds[sample] < self.rmsd_classification_cutoff[0]).unsqueeze(0).float()
            complex_graph.rmsd = torch.tensor(rmsds[sample]).unsqueeze(0).float()

        # All time attributes are set to zero for every node type and for the complex.
        complex_graph['ligand'].node_t = self._zero_times(complex_graph['ligand'].num_nodes)
        complex_graph['receptor'].node_t = self._zero_times(complex_graph['receptor'].num_nodes)
        if self.all_atoms:
            complex_graph['atom'].node_t = self._zero_times(complex_graph['atom'].num_nodes)
        complex_graph.complex_t = self._zero_times(1)
        return complex_graph

    def preprocessing(self, original_model_cache):
        """Sample ligand poses with the original model and cache positions, RMSDs and names.

        For every complex in ``original_model_cache``'s ``heterographs.pkl``,
        ``samples_per_complex`` poses are sampled and their RMSD to the original
        ligand position (atoms with a non-zero first feature only) is computed.
        Results are pickled into ``self.full_cache_path``.

        Args:
            original_model_cache: Directory containing the ``heterographs.pkl``
                the original model was trained on.
        """
        t_to_sigma = partial(t_to_sigma_compl, args=self.original_model_args)

        model = get_model(self.original_model_args, self.device, t_to_sigma=t_to_sigma, no_parallel=True)
        state_dict = torch.load(f'{self.original_model_dir}/best_model.pt', map_location=torch.device('cpu'))
        model.load_state_dict(state_dict, strict=True)
        model = model.to(self.device)
        model.eval()

        # Translation, rotation and torsion share one schedule.
        tr_schedule = get_t_schedule(inference_steps=self.inference_steps)
        rot_schedule = tr_schedule
        tor_schedule = tr_schedule
        print('common t schedule', tr_schedule)

        graphs_file = os.path.join(original_model_cache, "heterographs.pkl")
        print('HAPPENING | loading cached complexes of the original model to create the confidence dataset RMSDs and predicted positions. Doing that from: ', graphs_file)
        with open(graphs_file, 'rb') as f:
            complex_graphs = pickle.load(f)
        dataset = ListDataset(complex_graphs)
        loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

        rmsds, full_ligand_positions, names = [], [], []
        for orig_complex_graph in tqdm(loader):
            data_list = [copy.deepcopy(orig_complex_graph) for _ in range(self.samples_per_complex)]
            randomize_position(data_list, self.original_model_args.no_torsion, False, self.original_model_args.tr_sigma_max)

            predictions_list = None
            failed_convergence_counter = 0
            while predictions_list is None:
                try:
                    predictions_list, _ = sampling(data_list=data_list, model=model, inference_steps=self.inference_steps,
                                                   tr_schedule=tr_schedule, rot_schedule=rot_schedule, tor_schedule=tor_schedule,
                                                   device=self.device, t_to_sigma=t_to_sigma, model_args=self.original_model_args)
                except Exception as e:
                    if 'failed to converge' not in str(e):
                        raise
                    failed_convergence_counter += 1
                    if failed_convergence_counter > 5:
                        print('| WARNING: SVD failed to converge 5 times - skipping the complex')
                        break
                    print('| WARNING: SVD failed to converge - trying again with a new sample')
            # After repeated failures the randomized (unsampled) inputs stand in for
            # predictions, so the complex still gets an entry in the cache.
            if failed_convergence_counter > 5:
                predictions_list = data_list

            if self.original_model_args.no_torsion:
                orig_complex_graph['ligand'].orig_pos = (orig_complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy())

            # Mask of ligand atoms whose first feature is non-zero (named as a hydrogen filter).
            filterHs = torch.not_equal(predictions_list[0]['ligand'].x[:, 0], 0).cpu().numpy()

            if isinstance(orig_complex_graph['ligand'].orig_pos, list):
                orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]

            ligand_pos = np.asarray([complex_graph['ligand'].pos.cpu().numpy()[filterHs] for complex_graph in predictions_list])
            orig_ligand_pos = np.expand_dims(orig_complex_graph['ligand'].orig_pos[filterHs] - orig_complex_graph.original_center.cpu().numpy(), axis=0)
            # Per-sample RMSD: squared distance summed over coordinates, averaged over atoms.
            rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))

            rmsds.append(rmsd)
            full_ligand_positions.append(np.asarray([complex_graph['ligand'].pos.cpu().numpy() for complex_graph in predictions_list]))
            # batch_size is 1, so the batched ``name`` must hold exactly one entry.
            assert (len(orig_complex_graph.name) == 1)
            names.append(orig_complex_graph.name[0])

        creation_suffix = '' if self.cache_creation_id is None else '_id' + str(self.cache_creation_id)
        with open(os.path.join(self.full_cache_path, f"ligand_positions{creation_suffix}.pkl"), 'wb') as f:
            pickle.dump((full_ligand_positions, rmsds), f)
        with open(os.path.join(self.full_cache_path, f"complex_names_in_same_order{creation_suffix}.pkl"), 'wb') as f:
            pickle.dump((names), f)