import pytorch_lightning as pl

import models
from systems.utils import parse_optimizer, parse_scheduler, update_module_step
from utils.mixins import SaverMixin
from utils.misc import config_to_primitive, get_rank


class BaseSystem(pl.LightningModule, SaverMixin):
    """
    Two ways to print to console:
    1. self.print: correctly handle progress bar
    2. rank_zero_info: use the logging module
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rank = get_rank()
        self.prepare()
        self.model = models.make(self.config.model.name, self.config.model)

    def prepare(self):
        pass

    def forward(self, batch):
        raise NotImplementedError

    def C(self, value):
        if isinstance(value, int) or isinstance(value, float):
            pass
        else:
            value = config_to_primitive(value)
            if not isinstance(value, list):
                raise TypeError('Scalar specification only supports list, got', type(value))
            if len(value) == 3:
                value = [0] + value
            assert len(value) == 4
            start_step, start_value, end_value, end_step = value
            if isinstance(end_step, int):
                current_step = self.global_step
                value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
            elif isinstance(end_step, float):
                current_step = self.current_epoch
                value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
        return value

    def preprocess_data(self, batch, stage):
        pass

    """
    Implementing on_after_batch_transfer of DataModule does the same.
    But on_after_batch_transfer does not support DP.
    """
    def on_train_batch_start(self, batch, batch_idx, unused=0):
        self.dataset = self.trainer.datamodule.train_dataloader().dataset
        self.preprocess_data(batch, 'train')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.val_dataloader().dataset
        self.preprocess_data(batch, 'validation')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.test_dataloader().dataset
        self.preprocess_data(batch, 'test')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.predict_dataloader().dataset
        self.preprocess_data(batch, 'predict')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    """
    # aggregate outputs from different devices (DP)
    def training_step_end(self, out):
        pass
    """

    """
    # aggregate outputs from different iterations
    def training_epoch_end(self, out):
        pass
    """

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    """
    # aggregate outputs from different devices when using DP
    def validation_step_end(self, out):
        pass
    """

    def validation_epoch_end(self, out):
        """
        Gather metrics from all devices, compute mean.
        Purge repeated results using data index.
        """
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def test_epoch_end(self, out):
        """
        Gather metrics from all devices, compute mean.
        Purge repeated results using data index.
        """
        raise NotImplementedError

    def export(self):
        raise NotImplementedError

    def configure_optimizers(self):
        optim = parse_optimizer(self.config.system.optimizer, self.model)
        ret = {
            'optimizer': optim,
        }
        if 'scheduler' in self.config.system:
            ret.update({
                'lr_scheduler': parse_scheduler(self.config.system.scheduler, optim),
            })
        return ret