import torch
import torchmetrics
from pytorch_lightning import LightningModule
from pytorch_lightning.cli import LightningCLI
from torch import nn

from data import FsrFgDataModule
from fsr_fg_model import FsrFgModel


class FsrFgLightning(LightningModule):
    def __init__(
        self,
        fg_input_dim=2786,
        mfg_input_dim=2586,
        num_input_dim=208,
        enc_dec_dims=(500, 100),
        output_dims=(200, 100, 50),
        num_tasks=2,
        dropout=0.8,
        method="FGR",
        lr=1e-4,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(
            "fg_input_dim", "mfg_input_dim", "num_input_dim", "enc_dec_dims",
            "output_dims", "num_tasks", "dropout", "method", "lr",
        )
        self.net = FsrFgModel(
            fg_input_dim, mfg_input_dim, num_input_dim,
            enc_dec_dims, output_dims, num_tasks, dropout, method,
        )
        self.lr = lr
        self.method = method
        self.criterion = nn.CrossEntropyLoss()
        self.recon_loss = nn.BCEWithLogitsLoss()
        self.softmax = nn.Softmax(dim=1)
        # One AUROC metric per split so running state is never shared
        # across train/val/test (torchmetrics >= 0.11 API).
        self.train_auc = torchmetrics.AUROC(task="multiclass", num_classes=num_tasks)
        self.valid_auc = torchmetrics.AUROC(task="multiclass", num_classes=num_tasks)
        self.test_auc = torchmetrics.AUROC(task="multiclass", num_classes=num_tasks)

    def forward(self, fg, mfg, num_features):
        # Route the inputs according to the representation method; the
        # underlying network returns a (logits, reconstruction) tuple.
        if self.method == "FG":
            return self.net(fg=fg)
        if self.method == "MFG":
            return self.net(mfg=mfg)
        if self.method == "FGR":
            return self.net(fg=fg, mfg=mfg)
        return self.net(fg=fg, mfg=mfg, num_features=num_features)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.lr, weight_decay=0.3)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-2, total_steps=self.trainer.estimated_stepping_batches
        )
        # OneCycleLR is sized in optimizer steps, so it must be stepped every
        # batch rather than once per epoch (Lightning's default interval).
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def training_step(self, batch, batch_idx):
        fg, mfg, num_features, y = batch
        y_pred, recon = self(fg, mfg, num_features)
        # The reconstruction target matches whichever fingerprint(s) the
        # chosen method actually encoded.
        if self.method == "FG":
            loss_r_pre = 1e-4 * self.recon_loss(recon, fg)
        elif self.method == "MFG":
            loss_r_pre = 1e-4 * self.recon_loss(recon, mfg)
        else:
            loss_r_pre = 1e-4 * self.recon_loss(recon, torch.cat([fg, mfg], dim=1))
        loss = self.criterion(y_pred, y) + loss_r_pre
        self.train_auc(self.softmax(y_pred), y)
        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.log("train_auc", self.train_auc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        fg, mfg, num_features, y = batch
        y_pred, recon = self(fg, mfg, num_features)
        loss = self.criterion(y_pred, y)
        self.valid_auc(self.softmax(y_pred), y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_auc", self.valid_auc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        fg, mfg, num_features, y = batch
        y_pred, recon = self(fg, mfg, num_features)
        loss = self.criterion(y_pred, y)
        self.test_auc(self.softmax(y_pred), y)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("test_auc", self.test_auc, on_step=False, on_epoch=True, prog_bar=True, logger=True)


if __name__ == "__main__":
    cli = LightningCLI(
        model_class=FsrFgLightning,
        datamodule_class=FsrFgDataModule,
        save_config_callback=None,
        run=False,
    )
    cli.trainer.fit(cli.model, cli.datamodule)
    cli.trainer.test(cli.model, cli.datamodule, ckpt_path="best")
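
# Example invocation (a sketch). With run=False, LightningCLI only parses
# arguments and instantiates the model, datamodule, and trainer; fit() and
# test() are then called explicitly above. The script name is hypothetical,
# and the --data.* / --trainer.* flags assume FsrFgDataModule and Trainer
# expose those parameters:
#
#   python train_fsr_fg.py --model.method FGR --model.lr 1e-4 \
#       --data.batch_size 64 --trainer.max_epochs 100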