from typing import Optional
from pathlib import Path
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch_scatter import scatter_mean

from src.model.gvp import GVP, GVPModel, LayerNorm, GVPConvLayer
from src.model.gvp_transformer import GVPTransformerModel, GVPTransformerLayer
from src.constants import aa_decoder, residue_bond_encoder
from src.data.dataset import ProcessedLigandPocketDataset
import src.utils as utils


class SizeModel(pl.LightningModule):
    def __init__(
            self,
            max_size,
            pocket_representation,
            train_params,
            loss_params,
            eval_params,
            predictor_params,
    ):
        super().__init__()
        self.save_hyperparameters()

        assert pocket_representation == "CA+"
        self.pocket_representation = pocket_representation

        self.type = loss_params.type
        assert self.type in {'classifier', 'ordinal', 'regression'}

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.data_transform = None

        self.datadir = train_params.datadir
        self.batch_size = train_params.batch_size
        self.lr = train_params.lr
        self.num_workers = train_params.num_workers
        self.clip_grad = train_params.clip_grad
        if self.clip_grad:
            self.gradnorm_queue = utils.Queue()
            # Seed the queue with a large value so that early updates are not
            # clipped too aggressively.
            self.gradnorm_queue.add(3000)

        self.aa_decoder = aa_decoder
        self.residue_bond_encoder = residue_bond_encoder

        self.edge_cutoff = predictor_params.edge_cutoff
        self.add_nma_feat = predictor_params.normal_modes
        self.max_size = max_size
        # Ordinal models use max_size logits (one per threshold P(size > k),
        # k = 0..max_size - 1); classifiers use max_size + 1 classes (0..max_size).
        self.n_classes = max_size if self.type == 'ordinal' else max_size + 1
        backbone = predictor_params.backbone
        model_params = getattr(predictor_params, backbone + '_params')

        # Node features: scalar channels hold the amino-acid one-hot encoding,
        # vector channels hold optional normal mode displacement vectors.
        self.residue_nf = (len(self.aa_decoder), 0)
        if self.add_nma_feat:
            self.residue_nf = (self.residue_nf[0], self.residue_nf[1] + 5)

        out_nf = 1 if self.type == "regression" else self.n_classes

        if backbone == 'gvp_transformer':
            self.net = SizeGVPTransformer(
                node_in_dim=self.residue_nf,
                node_h_dim=model_params.node_h_dim,
                out_nf=out_nf,
                edge_in_nf=len(self.residue_bond_encoder),
                edge_h_dim=model_params.edge_h_dim,
                num_layers=model_params.n_layers,
                dk=model_params.dk,
                dv=model_params.dv,
                de=model_params.de,
                db=model_params.db,
                dy=model_params.dy,
                attn_heads=model_params.attn_heads,
                n_feedforward=model_params.n_feedforward,
                drop_rate=model_params.dropout,
                reflection_equiv=model_params.reflection_equivariant,
                d_max=model_params.d_max,
                num_rbf=model_params.num_rbf,
                vector_gate=model_params.vector_gate,
                attention=model_params.attention,
            )

        elif backbone == 'gvp_gnn':
            self.net = SizeGVPModel(
                node_in_dim=self.residue_nf,
                node_h_dim=model_params.node_h_dim,
                out_nf=out_nf,
                edge_in_nf=len(self.residue_bond_encoder),
                edge_h_dim=model_params.edge_h_dim,
                num_layers=model_params.n_layers,
                drop_rate=model_params.dropout,
                vector_gate=model_params.vector_gate,
                reflection_equiv=model_params.reflection_equivariant,
                d_max=model_params.d_max,
                num_rbf=model_params.num_rbf,
            )

        else:
            raise NotImplementedError(f"{backbone} is not available")

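    # Example (hypothetical) layout of `predictor_params`, inferred from the
    # attribute accesses above; the concrete values are illustrative only and
    # not taken from any shipped config:
    #
    #   predictor_params:
    #     edge_cutoff: 10.0
    #     normal_modes: true
    #     backbone: gvp_gnn
    #     gvp_gnn_params:
    #       node_h_dim: [128, 16]
    #       edge_h_dim: [32, 1]
    #       n_layers: 3
    #       dropout: 0.1
    #       vector_gate: true
    #       reflection_equivariant: true
    #       d_max: 20.0
    #       num_rbf: 16
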
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr,
                                 amsgrad=True, weight_decay=1e-12)

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit':
            self.train_dataset = ProcessedLigandPocketDataset(
                Path(self.datadir, 'train.pt'),
                ligand_transform=None, catch_errors=True)
            self.val_dataset = ProcessedLigandPocketDataset(
                Path(self.datadir, 'val.pt'), ligand_transform=None)
        elif stage == 'test':
            self.test_dataset = ProcessedLigandPocketDataset(
                Path(self.datadir, 'test.pt'), ligand_transform=None)
        else:
            raise NotImplementedError

    def train_dataloader(self):
        return DataLoader(self.train_dataset, self.batch_size, shuffle=True,
                          num_workers=self.num_workers,
                          collate_fn=partial(self.train_dataset.collate_fn,
                                             ligand_transform=self.data_transform),
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, self.batch_size,
                          shuffle=False, num_workers=self.num_workers,
                          collate_fn=self.val_dataset.collate_fn,
                          pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, self.batch_size, shuffle=False,
                          num_workers=self.num_workers,
                          collate_fn=self.test_dataset.collate_fn,
                          pin_memory=True)

    def forward(self, pocket):
        x, h, mask = pocket['x'], pocket['one_hot'], pocket['mask']

        edges = None
        if 'bonds' in pocket:
            edges = (pocket['bonds'], pocket['bond_one_hot'])

        v = None
        if self.add_nma_feat:
            v = pocket['nma_vec']

        edge_feat = None
        if edges is not None:
            # Make the residue bond graph bidirectional.
            edge_indices = torch.cat(
                [edges[0], edges[0].flip(dims=[0])], dim=1)
            edge_types = torch.cat([edges[1], edges[1]], dim=0)

            edges, edge_feat = self.get_edges(
                mask, x, bond_inds=edge_indices, bond_feat=edge_types)

            # Edges must not connect residues from different samples.
            assert torch.all(mask[edges[0]] == mask[edges[1]])

        out = self.net(h, x, edges, v=v, batch_mask=mask, edge_attr=edge_feat)

        if torch.any(torch.isnan(out)):
            if self.training:
                print("NaN detected in network output")
                out[torch.isnan(out)] = 0.0
            else:
                raise ValueError("NaN detected in network output")

        return out

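    # Expected entries of the `pocket` dictionary (shapes inferred from the
    # usage in forward/get_edges; N = number of pocket residues):
    #   'x':            (N, 3) residue (C-alpha) coordinates
    #   'one_hot':      (N, len(aa_decoder)) residue type one-hot encoding
    #   'mask':         (N,) index of the sample each residue belongs to
    #   'bonds':        (2, n_bonds) residue bond index pairs (optional)
    #   'bond_one_hot': (n_bonds, len(residue_bond_encoder)) bond types (optional)
    #   'nma_vec':      (N, 5, 3) normal mode vectors (only if normal_modes is set)
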
    def get_edges(self, batch_mask, coord, bond_inds=None, bond_feat=None,
                  self_edges=False):
        # Connect residues that belong to the same sample ...
        adj = batch_mask[:, None] == batch_mask[None, :]

        # ... and lie within the distance cutoff of each other ...
        if self.edge_cutoff is not None:
            adj = adj & (torch.cdist(coord, coord) <= self.edge_cutoff)

        # ... and always keep covalently bonded pairs.
        adj[bond_inds[0], bond_inds[1]] = True

        if not self_edges:
            adj = adj ^ torch.eye(*adj.size(), out=torch.empty_like(adj))

        # Pairs without a covalent bond are assigned the 'NOBOND' type.
        nobond_onehot = F.one_hot(
            torch.tensor(self.residue_bond_encoder['NOBOND'],
                         device=bond_feat.device),
            num_classes=len(self.residue_bond_encoder)).float()

        feat = nobond_onehot.repeat(*adj.shape, 1)
        feat[bond_inds[0], bond_inds[1]] = bond_feat

        edges = torch.stack(torch.where(adj), dim=0)
        edge_feat = feat[edges[0], edges[1]]

        return edges, edge_feat

    def compute_loss(self, pred_logits, true_size):
        if self.type == "classifier":
            loss = F.cross_entropy(pred_logits, true_size)

        elif self.type == "ordinal":
            # Logit k is trained to predict whether the ligand has more than k atoms.
            binary_labels = true_size.unsqueeze(1) > torch.arange(
                self.n_classes, device=true_size.device).unsqueeze(0)
            loss = F.binary_cross_entropy_with_logits(
                pred_logits, binary_labels.float())

        elif self.type == 'regression':
            loss = F.mse_loss(pred_logits.squeeze(), true_size.float())

        else:
            raise NotImplementedError()

        return loss

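    # Worked example for the ordinal loss: with max_size = 5 there are
    # n_classes = 5 logits and logit k models P(size > k). A ligand with
    # true_size = 3 therefore gets the binary targets [1, 1, 1, 0, 0].
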
    def max_likelihood(self, pred_logits):
        if self.type == "classifier":
            pred = pred_logits.argmax(dim=-1)

        elif self.type == "ordinal":
            # Turn the cumulative probabilities P(size > k) into a categorical
            # distribution over sizes 0..n_classes and take its mode.
            prop_greater = pred_logits.sigmoid()
            pred = torch.zeros((pred_logits.size(0), pred_logits.size(1) + 1),
                               device=pred_logits.device)
            pred[:, 0] = 1 - prop_greater[:, 0]
            pred[:, 1:-1] = prop_greater[:, :-1] - prop_greater[:, 1:]
            pred[:, -1] = prop_greater[:, -1]
            pred = pred.argmax(dim=-1)

        elif self.type == 'regression':
            pred = torch.clip(torch.round(pred_logits),
                              min=0, max=self.max_size)
            pred = pred.squeeze()

        else:
            raise NotImplementedError()

        return pred

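    # Decoding used above for the ordinal head: with cumulative probabilities
    # p_k = P(size > k), the implied categorical distribution is
    #   P(size = 0) = 1 - p_0,
    #   P(size = k) = p_{k-1} - p_k            for 1 <= k <= n_classes - 1,
    #   P(size = n_classes) = p_{n_classes - 1},
    # and the predicted size is the argmax of these values.
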
    def log_metrics(self, metrics_dict, split, batch_size=None, **kwargs):
        for m, value in metrics_dict.items():
            self.log(f'{m}/{split}', value, batch_size=batch_size, **kwargs)

    def compute_metrics(self, pred_logits, target):
        pred = self.max_likelihood(pred_logits)

        accuracy = (pred == target).sum() / len(target)
        mse = torch.mean((target - pred).float() ** 2)

        # Fraction of predictions within +/- 1 and +/- 2 atoms of the target
        # (windows of width 3 and 5, respectively).
        acc_window3 = (torch.abs(target - pred) <= 1).sum() / len(target)
        acc_window5 = (torch.abs(target - pred) <= 2).sum() / len(target)

        return {'accuracy': accuracy,
                'mse': mse,
                'accuracy_window3': acc_window3,
                'accuracy_window5': acc_window5}

    def training_step(self, data, *args):
        ligand, pocket = data['ligand'], data['pocket']

        try:
            pred_logits = self.forward(pocket)
            true_size = ligand['size']

        except RuntimeError as e:
            # Only recover from out-of-memory errors on single-device runs.
            if self.trainer.num_devices < 2 and 'out of memory' in str(e):
                print('WARNING: ran out of memory, skipping to the next batch')
                return None
            else:
                raise e

        loss = self.compute_loss(pred_logits, true_size)

        metrics = self.compute_metrics(pred_logits, true_size)
        self.log_metrics({'loss': loss, **metrics}, 'train',
                         batch_size=len(true_size), prog_bar=False)

        return loss

    def validation_step(self, data, *args):
        ligand, pocket = data['ligand'], data['pocket']

        pred_logits = self.forward(pocket)
        true_size = ligand['size']

        loss = self.compute_loss(pred_logits, true_size)

        metrics = self.compute_metrics(pred_logits, true_size)
        self.log_metrics({'loss': loss, **metrics}, 'val',
                         batch_size=len(true_size))

        return loss

    def configure_gradient_clipping(self, optimizer, optimizer_idx,
                                    gradient_clip_val, gradient_clip_algorithm):
        if not self.clip_grad:
            return

        # Allow gradient norms up to 150% of the running mean plus twice the
        # running standard deviation of recent gradient norms.
        max_grad_norm = 1.5 * self.gradnorm_queue.mean() + \
                        2 * self.gradnorm_queue.std()

        # Compute the current gradient norm.
        params = [p for g in optimizer.param_groups for p in g['params']]
        grad_norm = utils.get_grad_norm(params)

        # Let Lightning perform the actual clipping.
        self.clip_gradients(optimizer, gradient_clip_val=max_grad_norm,
                            gradient_clip_algorithm='norm')

        if float(grad_norm) > max_grad_norm:
            self.gradnorm_queue.add(float(max_grad_norm))
        else:
            self.gradnorm_queue.add(float(grad_norm))

        if float(grad_norm) > max_grad_norm:
            print(f'Clipped gradient with value {grad_norm:.1f} '
                  f'while allowed {max_grad_norm:.1f}')


class SizeGVPTransformer(GVPTransformerModel):
    """
    GVP-Transformer model

    :param node_in_dim: node dimension in input graph, scalars or tuple
        (scalars, vectors)
    :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
    :param out_nf: number of scalar output features (per graph)
    :param edge_in_nf: edge dimension in input graph (scalars)
    :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN
        layers, tuple (s, V)
    :param num_layers: number of GVP-Transformer layers
    :param drop_rate: rate to use in all dropout layers
    :param reflection_equiv: bool, use reflection-sensitive feature based on
        the cross product if False
    :param d_max: maximum distance of the radial basis expansion of edge lengths
    :param num_rbf: number of radial basis functions
    :param vector_gate: use vector gates in all GVPs
    :param attention: can be used to turn off the attention mechanism
    """
    def __init__(self, node_in_dim, node_h_dim, out_nf, edge_in_nf,
                 edge_h_dim, num_layers, dk, dv, de, db, dy,
                 attn_heads, n_feedforward, drop_rate, reflection_equiv=True,
                 d_max=20.0, num_rbf=16, vector_gate=False, attention=True):
        # Skip GVPTransformerModel.__init__ on purpose and initialize nn.Module
        # directly; only the parent's helper methods (e.g. edge_features) are reused.
        super(GVPTransformerModel, self).__init__()

        self.reflection_equiv = reflection_equiv
        self.d_max = d_max
        self.num_rbf = num_rbf

        if not isinstance(node_in_dim, tuple):
            node_in_dim = (node_in_dim, 0)

        edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
        if not self.reflection_equiv:
            edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)

        self.W_v = GVP(node_in_dim, node_h_dim, activations=(None, None),
                       vector_gate=vector_gate)
        self.W_e = GVP(edge_in_dim, edge_h_dim, activations=(None, None),
                       vector_gate=vector_gate)

        self.dy = dy
        self.layers = nn.ModuleList(
            GVPTransformerLayer(node_h_dim, edge_h_dim, dy, dk, dv, de, db,
                                attn_heads, n_feedforward=n_feedforward,
                                drop_rate=drop_rate, vector_gate=vector_gate,
                                activations=(F.relu, None), attention=attention)
            for _ in range(num_layers))

        self.W_y_out = GVP(dy, (out_nf, 0), activations=(None, None),
                           vector_gate=vector_gate)

    def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
        bs = len(batch_mask.unique())

        h_v = h if v is None else (h, v)
        h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)

        h_v = self.W_v(h_v)
        h_e = self.W_e(h_e)
        # Graph-level (global) features, initialized to zero.
        h_y = (torch.zeros(bs, self.dy[0], device=h.device),
               torch.zeros(bs, self.dy[1], 3, device=h.device))

        for layer in self.layers:
            h_v, h_e, h_y = layer(h_v, edge_index, batch_mask, h_e, h_y)

        return self.W_y_out(h_y)


class SizeGVPModel(GVPModel):
    """
    GVP-GNN model
    inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
    and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115

    :param node_in_dim: node dimension in input graph, scalars or tuple
        (scalars, vectors)
    :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
    :param out_nf: number of scalar output features (per graph)
    :param edge_in_nf: edge dimension in input graph (scalars)
    :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN
        layers, tuple (s, V)
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    :param vector_gate: use vector gates in all GVPs
    :param reflection_equiv: bool, use reflection-sensitive feature based on
        the cross product if False
    :param d_max: maximum distance of the radial basis expansion of edge lengths
    :param num_rbf: number of radial basis functions
    """
    def __init__(self, node_in_dim, node_h_dim, out_nf,
                 edge_in_nf, edge_h_dim, num_layers=3, drop_rate=0.1,
                 vector_gate=False, reflection_equiv=True, d_max=20.0,
                 num_rbf=16):
        # Skip GVPModel.__init__ on purpose and initialize nn.Module directly;
        # only the parent's helper methods (e.g. edge_features) are reused.
        super(GVPModel, self).__init__()

        self.reflection_equiv = reflection_equiv
        self.d_max = d_max
        self.num_rbf = num_rbf

        if not isinstance(node_in_dim, tuple):
            node_in_dim = (node_in_dim, 0)

        edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
        if not self.reflection_equiv:
            edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)

        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim, learnable_vector_weight=True),
            GVP(node_in_dim, node_h_dim, activations=(None, None),
                vector_gate=vector_gate),
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim, learnable_vector_weight=True),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None),
                vector_gate=vector_gate),
        )

        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate,
                         update_edge_attr=True, activations=(F.relu, None),
                         vector_gate=vector_gate, ln_vector_weight=True)
            for _ in range(num_layers))

        self.W_y_out = nn.Sequential(
            LayerNorm(node_h_dim, learnable_vector_weight=True),
            GVP(node_h_dim, (out_nf, 0), activations=(None, None),
                vector_gate=vector_gate),
        )

    def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
        batch_size = len(torch.unique(batch_mask))

        h_v = h if v is None else (h, v)
        h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)

        h_v = self.W_v(h_v)
        h_e = self.W_e(h_e)

        for layer in self.layers:
            h_v, h_e = layer(h_v, edge_index, edge_attr=h_e)

        # Mean-pool node features per graph before the output projection.
        sm = scatter_mean(h_v[0], batch_mask, dim=0, dim_size=batch_size)
        vm = scatter_mean(h_v[1], batch_mask, dim=0, dim_size=batch_size)

        return self.W_y_out((sm, vm))
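

if __name__ == "__main__":
    # Minimal self-contained sketch (not part of the training pipeline):
    # decode one set of ordinal logits into a size prediction, mirroring
    # SizeModel.max_likelihood for loss type 'ordinal'. The logit values
    # below are made up for illustration.
    logits = torch.tensor([[4.0, 2.0, 0.5, -1.0, -3.0]])  # logit k ~ P(size > k)
    p_greater = logits.sigmoid()
    probs = torch.zeros(logits.size(0), logits.size(1) + 1)
    probs[:, 0] = 1 - p_greater[:, 0]
    probs[:, 1:-1] = p_greater[:, :-1] - p_greater[:, 1:]
    probs[:, -1] = p_greater[:, -1]
    print(probs.argmax(dim=-1))  # -> tensor([3]) for the logits above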