import pytorch_lightning as pl
import torch

from src.const import ZINC_TRAIN_LINKER_ID2SIZE, ZINC_TRAIN_LINKER_SIZE2ID
from src.linker_size import SizeGNN
from src.egnn import coord2diff
from src.datasets import ZincDataset, get_dataloader, collate_with_fragment_edges

from typing import Dict, List, Optional
from torch.nn.functional import cross_entropy, mse_loss, sigmoid
from pdb import set_trace


# Edges whose endpoints are at least this far apart are masked out before
# message passing.  NOTE(review): units are presumably Angstrom -- confirm.
_EDGE_DISTANCE_CUTOFF = 6


def _linker_sizes(linker_mask):
    """Return the per-sample sum of ``linker_mask`` as a 1-D tensor.

    The first axis is treated as the batch axis and everything else is
    flattened before summing, so the result always has shape ``(batch,)``.
    A bare ``.squeeze()`` would also drop the batch axis when the batch (or
    the node axis) has size one, yielding a 0-d tensor that cannot be
    iterated or passed to ``len``.

    NOTE(review): assumes ``linker_mask`` is batch-first, e.g.
    ``(batch, n_nodes, 1)`` -- confirm against the collate function.
    """
    return linker_mask.reshape(linker_mask.shape[0], -1).sum(-1)


def _sizes_to_labels(linker_mask, size2id, id2size):
    """Map per-sample linker sizes to class ids.

    Sizes that are missing from ``size2id`` fall back to the id of
    ``max(id2size)``.  The fallback is looked up lazily, only when a size is
    actually missing.

    :return: ``torch.long`` tensor of labels on ``linker_mask``'s device.
    """
    sizes = _linker_sizes(linker_mask).long().detach().cpu().tolist()
    labels = []
    for size in sizes:
        label = size2id.get(size)
        if label is None:
            label = size2id[max(id2size)]
        labels.append(label)
    return torch.tensor(labels, device=linker_mask.device, dtype=torch.long)


class _SizeModelBase(pl.LightningModule):
    """Shared data-loading, logging and optimisation logic of the size models.

    Subclasses must set ``data_path``, ``train_data_prefix``,
    ``val_data_prefix``, ``batch_size``, ``lr``, ``torch_device`` and ``gnn``
    in their ``__init__`` and implement ``forward`` (returning
    ``(output, loss)``) and ``_predictions_and_targets``.
    """

    train_dataset = None
    val_dataset = None
    test_dataset = None
    # Class-level default kept for backward compatibility; every instance gets
    # its own dict in ``__init__`` so runs do not share metric history.
    metrics: Dict[str, List[float]] = {}

    def __init__(self):
        super().__init__()
        self.metrics = {}

    # ------------------------------------------------------------------ data

    def _load_dataset(self, prefix):
        """Build a ``ZincDataset`` for ``prefix`` under ``self.data_path``."""
        return ZincDataset(
            data_path=self.data_path,
            prefix=prefix,
            device=self.torch_device
        )

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage``.

        ``'fit'`` loads the train and validation sets; ``'val'`` (and
        Lightning's own stage name ``'validate'``) loads only the validation
        set.

        :raises NotImplementedError: for any other stage.
        """
        if stage == 'fit':
            self.train_dataset = self._load_dataset(self.train_data_prefix)
            self.val_dataset = self._load_dataset(self.val_data_prefix)
        elif stage in ('val', 'validate'):
            self.val_dataset = self._load_dataset(self.val_data_prefix)
        else:
            raise NotImplementedError

    def train_dataloader(self):
        return get_dataloader(self.train_dataset, self.batch_size, collate_fn=collate_with_fragment_edges, shuffle=True)

    def val_dataloader(self):
        return get_dataloader(self.val_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)

    def test_dataloader(self):
        return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)

    # --------------------------------------------------------------- network

    def _pooled_gnn_output(self, h, x, fragment_mask, edge_mask, edges):
        """Run the GNN on already-masked inputs and mean-pool over nodes.

        ``x`` and ``h`` are flattened from ``(bs, n_nodes, ...)`` to
        ``(bs * n_nodes, ...)``, edges whose distance is not below
        ``_EDGE_DISTANCE_CUTOFF`` are masked out, and the per-node GNN output
        is averaged over the node axis.

        :return: tensor with ``bs`` rows (one per sample).
        """
        # Reshaping
        bs, n_nodes = x.shape[0], x.shape[1]
        fragment_mask = fragment_mask.view(bs * n_nodes, 1)
        x = x.view(bs * n_nodes, -1)
        h = h.view(bs * n_nodes, -1)

        # Prediction
        distances, _ = coord2diff(x, edges)
        distance_edge_mask = (edge_mask.bool() & (distances < _EDGE_DISTANCE_CUTOFF)).long()
        output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
        return output.view(bs, n_nodes, -1).mean(1)

    def _predictions_and_targets(self, data):
        """Return ``(pred, true)`` 1-D tensors used for validation accuracy."""
        raise NotImplementedError

    # ----------------------------------------------------------------- steps

    def training_step(self, data, *args):
        _, loss = self.forward(data)
        return {'loss': loss}

    def validation_step(self, data, *args):
        _, loss = self.forward(data)
        return {'loss': loss}

    def test_step(self, data, *args):
        # ``forward`` returns ``(output, loss)``; only the loss is reported.
        _, loss = self.forward(data)
        return {'loss': loss}

    # ------------------------------------------------------------- epoch end

    def _log_epoch_metrics(self, step_outputs, split):
        """Average every metric of ``step_outputs``, then store and log it."""
        for metric in step_outputs[0].keys():
            avg_metric = self.aggregate_metric(step_outputs, metric)
            self.metrics.setdefault(f'{metric}/{split}', []).append(avg_metric)
            self.log(f'{metric}/{split}', avg_metric, prog_bar=True)

    def training_epoch_end(self, training_step_outputs):
        self._log_epoch_metrics(training_step_outputs, 'train')

    def validation_epoch_end(self, validation_step_outputs):
        """Log averaged validation metrics plus accuracy over the val set.

        Accuracy is computed by a second pass over ``val_dataloader()``.
        """
        self._log_epoch_metrics(validation_step_outputs, 'val')

        correct = 0
        total = 0
        for data in self.val_dataloader():
            pred, true = self._predictions_and_targets(data)
            correct += (pred == true).sum()
            total += len(pred)

        accuracy = correct / total
        self.metrics.setdefault(f'accuracy/val', []).append(accuracy)
        self.log(f'accuracy/val', accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.gnn.parameters(), lr=self.lr, amsgrad=True, weight_decay=1e-12)

    @staticmethod
    def aggregate_metric(step_outputs, metric):
        """Mean of ``metric`` over all step outputs."""
        return torch.tensor([out[metric] for out in step_outputs]).mean()


class SizeClassifier(_SizeModelBase):
    """Predicts the linker-size class of a sample from its fragments.

    Trained with (optionally weighted) cross-entropy over the ids in
    ``linker_size2id``.
    """

    def __init__(
            self, data_path, train_data_prefix, val_data_prefix,
            in_node_nf, hidden_nf, out_node_nf, n_layers, batch_size, lr, torch_device,
            normalization=None,
            loss_weights=None,
            min_linker_size=None,
            linker_size2id=ZINC_TRAIN_LINKER_SIZE2ID,
            linker_id2size=ZINC_TRAIN_LINKER_ID2SIZE,
            task='classification',
    ):
        super(SizeClassifier, self).__init__()

        self.save_hyperparameters()
        self.data_path = data_path
        self.train_data_prefix = train_data_prefix
        self.val_data_prefix = val_data_prefix
        self.min_linker_size = min_linker_size
        self.linker_size2id = linker_size2id
        self.linker_id2size = linker_id2size
        self.batch_size = batch_size
        self.lr = lr
        # NOTE(review): the ``torch_device`` argument is not used here; the
        # device is auto-detected instead (the argument is still recorded by
        # ``save_hyperparameters``).  Confirm this is intentional.
        self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
        self.in_node_nf = in_node_nf
        self.gnn = SizeGNN(
            in_node_nf=in_node_nf,
            hidden_nf=hidden_nf,
            out_node_nf=out_node_nf,
            n_layers=n_layers,
            device=self.torch_device,
            normalization=normalization,
        )

    def forward(self, data, return_loss=True, with_pocket=False, adjust_shape=False):
        """Compute class scores (and optionally the loss) for a batch.

        :param data: batch dict with ``one_hot``, ``positions``, ``edge_mask``,
            ``edges``, a fragment mask and -- when ``return_loss`` is true --
            ``linker_mask``.
        :param return_loss: also compute the cross-entropy loss.
        :param with_pocket: read the mask from ``fragment_only_mask`` instead
            of ``fragment_mask``.
        :param adjust_shape: if the feature width differs from
            ``in_node_nf``, drop the last feature channel (which must be all
            zeros after masking).
        :return: ``(output, loss)``; ``loss`` is ``None`` when not requested.
        """
        fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']

        # Considering only fragments
        x = data['positions'] * fragment_mask
        h = data['one_hot'] * fragment_mask

        if h.shape[-1] != self.in_node_nf and adjust_shape:
            assert torch.allclose(h[..., -1], torch.zeros_like(h[..., -1]))
            h = h[..., :-1]

        output = self._pooled_gnn_output(h, x, fragment_mask, data['edge_mask'], data['edges'])

        if return_loss:
            true = self.get_true_labels(data['linker_mask'])
            loss = cross_entropy(output, true, weight=self.loss_weights)
        else:
            loss = None

        return output, loss

    def get_true_labels(self, linker_mask):
        """Class id per sample, derived from the linker size in ``linker_mask``."""
        return _sizes_to_labels(linker_mask, self.linker_size2id, self.linker_id2size)

    def _predictions_and_targets(self, data):
        output, _ = self.forward(data)
        pred = output.argmax(dim=-1)
        true = self.get_true_labels(data['linker_mask'])
        return pred, true


class SizeOrdinalClassifier(_SizeModelBase):
    """Predicts the linker-size class with an ordinal (cumulative) encoding.

    Output ``k`` is trained towards 1 when the true label is ``>= k`` and
    towards 0 otherwise.
    """

    def __init__(
            self, data_path, train_data_prefix, val_data_prefix,
            in_node_nf, hidden_nf, out_node_nf, n_layers, batch_size, lr, torch_device,
            normalization=None,
            min_linker_size=None,
            linker_size2id=ZINC_TRAIN_LINKER_SIZE2ID,
            linker_id2size=ZINC_TRAIN_LINKER_ID2SIZE,
            task='ordinal',
    ):
        super(SizeOrdinalClassifier, self).__init__()

        self.save_hyperparameters()
        self.data_path = data_path
        self.train_data_prefix = train_data_prefix
        self.val_data_prefix = val_data_prefix
        self.min_linker_size = min_linker_size
        self.batch_size = batch_size
        self.lr = lr
        self.torch_device = torch_device
        self.linker_size2id = linker_size2id
        self.linker_id2size = linker_id2size
        self.gnn = SizeGNN(
            in_node_nf=in_node_nf,
            hidden_nf=hidden_nf,
            out_node_nf=out_node_nf,
            n_layers=n_layers,
            device=torch_device,
            normalization=normalization,
        )

    def forward(self, data):
        """Return ``(output, loss)`` with sigmoid-activated ordinal outputs."""
        fragment_mask = data['fragment_mask']

        # Considering only fragments
        x = data['positions'] * fragment_mask
        h = data['one_hot'] * fragment_mask

        output = self._pooled_gnn_output(h, x, fragment_mask, data['edge_mask'], data['edges'])
        output = sigmoid(output)

        true = self.get_true_labels(data['linker_mask'])
        loss = self.ordinal_loss(output, true)

        return output, loss

    def ordinal_loss(self, pred, true):
        """Summed-per-sample, batch-averaged MSE against the ordinal target.

        Row ``i`` of the target is 1 in columns ``0..true[i]`` and 0 elsewhere.
        The target is built on ``pred``'s device so the loss does not depend
        on ``self.torch_device`` matching the batch.
        """
        true = torch.as_tensor(true, device=pred.device)
        thresholds = torch.arange(pred.shape[1], device=pred.device)
        target = (thresholds.unsqueeze(0) <= true.unsqueeze(1)).to(pred.dtype)
        return mse_loss(pred, target, reduction='none').sum(1).mean()

    def get_true_labels(self, linker_mask):
        """Class id per sample, derived from the linker size in ``linker_mask``."""
        return _sizes_to_labels(linker_mask, self.linker_size2id, self.linker_id2size)

    @staticmethod
    def prediction2label(pred):
        """Decode ordinal outputs: count leading entries above 0.5, minus one."""
        return torch.cumprod((pred > 0.5).long(), dim=1).sum(dim=1) - 1

    def _predictions_and_targets(self, data):
        output, _ = self.forward(data)
        pred = self.prediction2label(output)
        true = self.get_true_labels(data['linker_mask'])
        return pred, true


class SizeRegressor(_SizeModelBase):
    """Regresses the linker size directly with an MSE loss."""

    def __init__(
            self, data_path, train_data_prefix, val_data_prefix,
            in_node_nf, hidden_nf, n_layers, batch_size, lr, torch_device,
            normalization=None, task='regression',
    ):
        super(SizeRegressor, self).__init__()

        self.save_hyperparameters()
        self.data_path = data_path
        self.train_data_prefix = train_data_prefix
        self.val_data_prefix = val_data_prefix
        self.batch_size = batch_size
        self.lr = lr
        self.torch_device = torch_device
        self.gnn = SizeGNN(
            in_node_nf=in_node_nf,
            hidden_nf=hidden_nf,
            out_node_nf=1,
            n_layers=n_layers,
            device=torch_device,
            normalization=normalization,
        )

    def forward(self, data):
        """Return ``(output, loss)``; ``output`` holds one predicted size per sample."""
        fragment_mask = data['fragment_mask']

        # Considering only fragments
        x = data['positions'] * fragment_mask
        h = data['one_hot'] * fragment_mask

        # Drop only the trailing feature axis so a batch of one keeps its
        # batch dimension.
        output = self._pooled_gnn_output(h, x, fragment_mask, data['edge_mask'], data['edges']).squeeze(-1)

        true = _linker_sizes(data['linker_mask']).float()
        loss = mse_loss(output, true)

        return output, loss

    def _predictions_and_targets(self, data):
        output, _ = self.forward(data)
        pred = torch.round(output).long()
        true = _linker_sizes(data['linker_mask']).long()
        return pred, true