DiffLinker / src /lightning.py
igashov
update device
d5b42eb
raw
history blame
19.7 kB
import numpy as np
import os
import pytorch_lightning as pl
import torch
import wandb
from src import metrics, utils, delinker
from src.const import LINKER_SIZE_DIST
from src.egnn import Dynamics, DynamicsWithPockets
from src.edm import EDM, InpaintingEDM
from src.datasets import (
ZincDataset, MOADDataset, create_templates_for_linker_generation, get_dataloader, collate
)
from src.linker_size import DistributionNodes
from src.molecule_builder import build_molecules
from src.visualizer import save_xyz_file, visualize_chain
from typing import Dict, List, Optional
from tqdm import tqdm
from pdb import set_trace
def get_activation(activation):
print(activation)
if activation == 'silu':
return torch.nn.SiLU()
else:
raise Exception("activation fn not supported yet. Add it here.")
class DDPM(pl.LightningModule):
train_dataset = None
val_dataset = None
test_dataset = None
starting_epoch = None
metrics: Dict[str, List[float]] = {}
FRAMES = 100
def __init__(
self,
in_node_nf, n_dims, context_node_nf, hidden_nf, activation, tanh, n_layers, attention, norm_constant,
inv_sublayers, sin_embedding, normalization_factor, aggregation_method,
diffusion_steps, diffusion_noise_schedule, diffusion_noise_precision, diffusion_loss_type,
normalize_factors, include_charges, model,
data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
center_of_mass='fragments', inpainting=False, anchors_context=True,
):
super(DDPM, self).__init__()
self.save_hyperparameters()
self.data_path = data_path
self.train_data_prefix = train_data_prefix
self.val_data_prefix = val_data_prefix
self.batch_size = batch_size
self.lr = lr
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.include_charges = include_charges
self.test_epochs = test_epochs
self.n_stability_samples = n_stability_samples
self.log_iterations = log_iterations
self.samples_dir = samples_dir
self.data_augmentation = data_augmentation
self.center_of_mass = center_of_mass
self.inpainting = inpainting
self.loss_type = diffusion_loss_type
self.n_dims = n_dims
self.num_classes = in_node_nf - include_charges
self.include_charges = include_charges
self.anchors_context = anchors_context
self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
if type(activation) is str:
activation = get_activation(activation)
dynamics_class = DynamicsWithPockets if '.' in train_data_prefix else Dynamics
dynamics = dynamics_class(
in_node_nf=in_node_nf,
n_dims=n_dims,
context_node_nf=context_node_nf,
device=self.torch_device,
hidden_nf=hidden_nf,
activation=activation,
n_layers=n_layers,
attention=attention,
tanh=tanh,
norm_constant=norm_constant,
inv_sublayers=inv_sublayers,
sin_embedding=sin_embedding,
normalization_factor=normalization_factor,
aggregation_method=aggregation_method,
model=model,
normalization=normalization,
centering=inpainting,
)
edm_class = InpaintingEDM if inpainting else EDM
self.edm = edm_class(
dynamics=dynamics,
in_node_nf=in_node_nf,
n_dims=n_dims,
timesteps=diffusion_steps,
noise_schedule=diffusion_noise_schedule,
noise_precision=diffusion_noise_precision,
loss_type=diffusion_loss_type,
norm_values=normalize_factors,
)
self.linker_size_sampler = DistributionNodes(LINKER_SIZE_DIST)
def setup(self, stage: Optional[str] = None):
dataset_type = MOADDataset if '.' in self.train_data_prefix else ZincDataset
if stage == 'fit':
self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
self.train_dataset = dataset_type(
data_path=self.data_path,
prefix=self.train_data_prefix,
device=self.torch_device
)
self.val_dataset = dataset_type(
data_path=self.data_path,
prefix=self.val_data_prefix,
device=self.torch_device
)
elif stage == 'val':
self.is_geom = ('geom' in self.val_data_prefix) or ('MOAD' in self.val_data_prefix)
self.val_dataset = dataset_type(
data_path=self.data_path,
prefix=self.val_data_prefix,
device=self.torch_device
)
else:
raise NotImplementedError
def train_dataloader(self, collate_fn=collate):
return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_fn, shuffle=True)
def val_dataloader(self, collate_fn=collate):
return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_fn)
def test_dataloader(self, collate_fn=collate):
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_fn)
def forward(self, data, training):
x = data['positions']
h = data['one_hot']
node_mask = data['atom_mask']
edge_mask = data['edge_mask']
anchors = data['anchors']
fragment_mask = data['fragment_mask']
linker_mask = data['linker_mask']
# Anchors and fragments labels are used as context
if self.anchors_context:
context = torch.cat([anchors, fragment_mask], dim=-1)
else:
context = fragment_mask
# Add information about pocket to the context
if '.' in self.train_data_prefix:
fragment_pocket_mask = fragment_mask
fragment_only_mask = data['fragment_only_mask']
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
if self.anchors_context:
context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
else:
context = torch.cat([fragment_only_mask, pocket_only_mask], dim=-1)
# Removing COM of fragment from the atom coordinates
if self.inpainting:
center_of_mass_mask = node_mask
elif self.center_of_mass == 'fragments':
center_of_mass_mask = fragment_mask
elif self.center_of_mass == 'anchors':
center_of_mass_mask = anchors
else:
raise NotImplementedError(self.center_of_mass)
x = utils.remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask)
utils.assert_partial_mean_zero_with_mask(x, node_mask, center_of_mass_mask)
# Applying random rotation
if training and self.data_augmentation:
x = utils.random_rotation(x)
return self.edm.forward(
x=x,
h=h,
node_mask=node_mask,
fragment_mask=fragment_mask,
linker_mask=linker_mask,
edge_mask=edge_mask,
context=context
)
def training_step(self, data, *args):
delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=True)
vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
if self.loss_type == 'l2':
loss = l2_loss
elif self.loss_type == 'vlb':
loss = vlb_loss
else:
raise NotImplementedError(self.loss_type)
training_metrics = {
'loss': loss,
'delta_log_px': delta_log_px,
'kl_prior': kl_prior,
'loss_term_t': loss_term_t,
'loss_term_0': loss_term_0,
'l2_loss': l2_loss,
'vlb_loss': vlb_loss,
'noise_t': noise_t,
'noise_0': noise_0
}
if self.log_iterations is not None and self.global_step % self.log_iterations == 0:
for metric_name, metric in training_metrics.items():
self.metrics.setdefault(f'{metric_name}/train', []).append(metric)
self.log(f'{metric_name}/train', metric, prog_bar=True)
return training_metrics
def validation_step(self, data, *args):
delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=False)
vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
if self.loss_type == 'l2':
loss = l2_loss
elif self.loss_type == 'vlb':
loss = vlb_loss
else:
raise NotImplementedError(self.loss_type)
return {
'loss': loss,
'delta_log_px': delta_log_px,
'kl_prior': kl_prior,
'loss_term_t': loss_term_t,
'loss_term_0': loss_term_0,
'l2_loss': l2_loss,
'vlb_loss': vlb_loss,
'noise_t': noise_t,
'noise_0': noise_0
}
def test_step(self, data, *args):
delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=False)
vlb_loss = kl_prior + loss_term_t + loss_term_0 - delta_log_px
if self.loss_type == 'l2':
loss = l2_loss
elif self.loss_type == 'vlb':
loss = vlb_loss
else:
raise NotImplementedError(self.loss_type)
return {
'loss': loss,
'delta_log_px': delta_log_px,
'kl_prior': kl_prior,
'loss_term_t': loss_term_t,
'loss_term_0': loss_term_0,
'l2_loss': l2_loss,
'vlb_loss': vlb_loss,
'noise_t': noise_t,
'noise_0': noise_0
}
def training_epoch_end(self, training_step_outputs):
for metric in training_step_outputs[0].keys():
avg_metric = self.aggregate_metric(training_step_outputs, metric)
self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
self.log(f'{metric}/train', avg_metric, prog_bar=True)
def validation_epoch_end(self, validation_step_outputs):
for metric in validation_step_outputs[0].keys():
avg_metric = self.aggregate_metric(validation_step_outputs, metric)
self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
self.log(f'{metric}/val', avg_metric, prog_bar=True)
if (self.current_epoch + 1) % self.test_epochs == 0:
sampling_results = self.sample_and_analyze(self.val_dataloader())
for metric_name, metric_value in sampling_results.items():
self.log(f'{metric_name}/val', metric_value, prog_bar=True)
self.metrics.setdefault(f'{metric_name}/val', []).append(metric_value)
# Logging the results corresponding to the best validation_and_connectivity
best_metrics, best_epoch = self.compute_best_validation_metrics()
self.log('best_epoch', int(best_epoch), prog_bar=True, batch_size=self.batch_size)
for metric, value in best_metrics.items():
self.log(f'best_{metric}', value, prog_bar=True, batch_size=self.batch_size)
def test_epoch_end(self, test_step_outputs):
for metric in test_step_outputs[0].keys():
avg_metric = self.aggregate_metric(test_step_outputs, metric)
self.metrics.setdefault(f'{metric}/test', []).append(avg_metric)
self.log(f'{metric}/test', avg_metric, prog_bar=True)
if (self.current_epoch + 1) % self.test_epochs == 0:
sampling_results = self.sample_and_analyze(self.test_dataloader())
for metric_name, metric_value in sampling_results.items():
self.log(f'{metric_name}/test', metric_value, prog_bar=True)
self.metrics.setdefault(f'{metric_name}/test', []).append(metric_value)
def generate_animation(self, chain_batch, node_mask, batch_i):
batch_indices, mol_indices = utils.get_batch_idx_for_animation(self.batch_size, batch_i)
for bi, mi in zip(batch_indices, mol_indices):
chain = chain_batch[:, bi, :, :]
name = f'mol_{mi}'
chain_output = os.path.join(self.samples_dir, f'epoch_{self.current_epoch}', name)
os.makedirs(chain_output, exist_ok=True)
one_hot = chain[:, :, 3:-1] if self.include_charges else chain[:, :, 3:]
positions = chain[:, :, :3]
chain_node_mask = torch.cat([node_mask[bi].unsqueeze(0) for _ in range(self.FRAMES)], dim=0)
names = [f'{name}_{j}' for j in range(self.FRAMES)]
save_xyz_file(chain_output, one_hot, positions, chain_node_mask, names=names, is_geom=self.is_geom)
visualize_chain(chain_output, wandb=wandb, mode=name, is_geom=self.is_geom)
def sample_and_analyze(self, dataloader):
pred_molecules = []
true_molecules = []
true_fragments = []
for b, data in tqdm(enumerate(dataloader), total=len(dataloader), desc='Sampling'):
atom_mask = data['atom_mask']
fragment_mask = data['fragment_mask']
# Save molecules without pockets
if '.' in self.train_data_prefix:
atom_mask = data['atom_mask'] - data['pocket_mask']
fragment_mask = data['fragment_only_mask']
true_molecules_batch = build_molecules(
data['one_hot'],
data['positions'],
atom_mask,
is_geom=self.is_geom,
)
true_fragments_batch = build_molecules(
data['one_hot'],
data['positions'],
fragment_mask,
is_geom=self.is_geom,
)
for sample_idx in tqdm(range(self.n_stability_samples)):
try:
chain_batch, node_mask = self.sample_chain(data, keep_frames=self.FRAMES)
except utils.FoundNaNException as e:
for idx in e.x_h_nan_idx:
smiles = data['name'][idx]
print(f'FoundNaNException: [xh], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
for idx in e.only_x_nan_idx:
smiles = data['name'][idx]
print(f'FoundNaNException: [x ], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
for idx in e.only_h_nan_idx:
smiles = data['name'][idx]
print(f'FoundNaNException: [ h], e={self.current_epoch}, b={b}, i={idx}: {smiles}')
continue
# Get final molecules from chains – for computing metrics
x, h = utils.split_features(
z=chain_batch[0],
n_dims=self.n_dims,
num_classes=self.num_classes,
include_charges=self.include_charges,
)
# Save molecules without pockets
if '.' in self.train_data_prefix:
node_mask = node_mask - data['pocket_mask']
one_hot = h['categorical']
pred_molecules_batch = build_molecules(one_hot, x, node_mask, is_geom=self.is_geom)
# Adding only results for valid ground truth molecules
for pred_mol, true_mol, frag in zip(pred_molecules_batch, true_molecules_batch, true_fragments_batch):
if metrics.is_valid(true_mol):
pred_molecules.append(pred_mol)
true_molecules.append(true_mol)
true_fragments.append(frag)
# Generate animation – will always do it for molecules with idx 0, 110 and 360
if self.samples_dir is not None and sample_idx == 0:
self.generate_animation(chain_batch=chain_batch, node_mask=node_mask, batch_i=b)
# Our own & DeLinker metrics
our_metrics = metrics.compute_metrics(
pred_molecules=pred_molecules,
true_molecules=true_molecules
)
delinker_metrics = delinker.get_delinker_metrics(
pred_molecules=pred_molecules,
true_molecules=true_molecules,
true_fragments=true_fragments
)
return {
**our_metrics,
**delinker_metrics
}
def sample_chain(self, data, sample_fn=None, keep_frames=None):
if sample_fn is None:
linker_sizes = data['linker_mask'].sum(1).view(-1).int()
else:
linker_sizes = sample_fn(data)
if self.inpainting:
template_data = data
else:
template_data = create_templates_for_linker_generation(data, linker_sizes)
x = template_data['positions']
node_mask = template_data['atom_mask']
edge_mask = template_data['edge_mask']
h = template_data['one_hot']
anchors = template_data['anchors']
fragment_mask = template_data['fragment_mask']
linker_mask = template_data['linker_mask']
# Anchors and fragments labels are used as context
if self.anchors_context:
context = torch.cat([anchors, fragment_mask], dim=-1)
else:
context = fragment_mask
# Add information about pocket to the context
if '.' in self.train_data_prefix:
fragment_pocket_mask = fragment_mask
fragment_only_mask = data['fragment_only_mask']
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
if self.anchors_context:
context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
else:
context = torch.cat([fragment_only_mask, pocket_only_mask], dim=-1)
# Removing COM of fragment from the atom coordinates
if self.inpainting:
center_of_mass_mask = node_mask
elif self.center_of_mass == 'fragments':
center_of_mass_mask = fragment_mask
elif self.center_of_mass == 'anchors':
center_of_mass_mask = anchors
else:
raise NotImplementedError(self.center_of_mass)
x = utils.remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask)
chain = self.edm.sample_chain(
x=x,
h=h,
node_mask=node_mask,
edge_mask=edge_mask,
fragment_mask=fragment_mask,
linker_mask=linker_mask,
context=context,
keep_frames=keep_frames,
)
return chain, node_mask
def configure_optimizers(self):
return torch.optim.AdamW(self.edm.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)
def compute_best_validation_metrics(self):
loss = self.metrics[f'validity_and_connectivity/val']
best_epoch = np.argmax(loss)
best_metrics = {
metric_name: metric_values[best_epoch]
for metric_name, metric_values in self.metrics.items()
if metric_name.endswith('/val')
}
return best_metrics, best_epoch
@staticmethod
def aggregate_metric(step_outputs, metric):
return torch.tensor([out[metric] for out in step_outputs]).mean()