import pytorch_lightning as pl
import models
from systems.utils import parse_optimizer, parse_scheduler, update_module_step
from utils.mixins import SaverMixin
from utils.misc import config_to_primitive, get_rank


class BaseSystem(pl.LightningModule, SaverMixin):
    """
    Two ways to print to console:
    1. self.print: correctly handles the progress bar
    2. rank_zero_info: uses the logging module
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rank = get_rank()
        self.prepare()
        self.model = models.make(self.config.model.name, self.config.model)

    def prepare(self):
        pass

    def forward(self, batch):
        raise NotImplementedError

    def C(self, value):
        """
        Resolve a possibly-scheduled scalar. Plain ints/floats are returned as-is.
        A list [start_step, start_value, end_value, end_step] (or the 3-element form
        [start_value, end_value, end_step], which implies start_step=0) is linearly
        interpolated between start_value and end_value, clamped to that range.
        An integer end_step is measured in global steps, a float end_step in epochs.
        """
        if isinstance(value, int) or isinstance(value, float):
            pass
        else:
            value = config_to_primitive(value)
            if not isinstance(value, list):
                raise TypeError('Scalar specification only supports list, got', type(value))
            if len(value) == 3:
                value = [0] + value
            assert len(value) == 4
            start_step, start_value, end_value, end_step = value
            if isinstance(end_step, int):
                current_step = self.global_step
                value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
            elif isinstance(end_step, float):
                current_step = self.current_epoch
                value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
        return value
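
    # Illustrative example (not part of the original code): with a hypothetical
    # config entry such as `lambda_foo: [0, 1.0, 0.1, 5000]`, self.C(...) returns
    # 1.0 at global step 0, 0.55 at step 2500 (1.0 + (0.1 - 1.0) * 0.5), and 0.1
    # from step 5000 onwards, since the interpolation factor is clamped to [0, 1].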

    def preprocess_data(self, batch, stage):
        pass

    """
    Implementing on_after_batch_transfer of DataModule does the same.
    But on_after_batch_transfer does not support DP.
    """
    def on_train_batch_start(self, batch, batch_idx, unused=0):
        self.dataset = self.trainer.datamodule.train_dataloader().dataset
        self.preprocess_data(batch, 'train')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.val_dataloader().dataset
        self.preprocess_data(batch, 'validation')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.test_dataloader().dataset
        self.preprocess_data(batch, 'test')
        update_module_step(self.model, self.current_epoch, self.global_step)

    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx):
        self.dataset = self.trainer.datamodule.predict_dataloader().dataset
        self.preprocess_data(batch, 'predict')
        update_module_step(self.model, self.current_epoch, self.global_step)
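
    # Illustrative sketch (not from the original code): a concrete system would
    # typically override preprocess_data to turn the raw batch into whatever its
    # forward/training_step expects, e.g.:
    #
    #     def preprocess_data(self, batch, stage):
    #         # hypothetical field names; the real keys depend on the dataset
    #         batch['rays'] = batch['rays'].to(self.device)
    #         if stage == 'train':
    #             batch['rgb'] = batch['rgb'].to(self.device)
    #
    # The hooks above call it for every split so this logic lives in one place.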

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    """
    # aggregate outputs from different devices (DP)
    def training_step_end(self, out):
        pass
    """

    """
    # aggregate outputs from different iterations
    def training_epoch_end(self, out):
        pass
    """

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    """
    # aggregate outputs from different devices when using DP
    def validation_step_end(self, out):
        pass
    """

    def validation_epoch_end(self, out):
        """
        Gather metrics from all devices, compute mean.
        Purge repeated results using data index.
        """
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def test_epoch_end(self, out):
        """
        Gather metrics from all devices, compute mean.
        Purge repeated results using data index.
        """
        raise NotImplementedError

    def export(self):
        raise NotImplementedError

    def configure_optimizers(self):
        optim = parse_optimizer(self.config.system.optimizer, self.model)
        ret = {
            'optimizer': optim,
        }
        if 'scheduler' in self.config.system:
            ret.update({
                'lr_scheduler': parse_scheduler(self.config.system.scheduler, optim),
            })
        return ret
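
# Illustrative config shape (an assumption, not taken from this repo's configs):
# configure_optimizers only requires `system.optimizer` to be understood by
# parse_optimizer and, optionally, `system.scheduler` by parse_scheduler, e.g.
# something along the lines of:
#
#   system:
#     optimizer:
#       name: Adam
#       args: {lr: 1.e-3}
#     scheduler:
#       name: StepLR
#       args: {step_size: 10000, gamma: 0.5}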