import os import pytorch_lightning as pl import torch from typing import Any from modules.params.diffusion.inference_params import InferenceParams from modules.loader.module_loader import GenericModuleLoader from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams class AbstractTrainer(pl.LightningModule): def __init__(self, inference_params: Any, diff_trainer_params: DiffusionTrainerParams, module_loader: GenericModuleLoader, ): super().__init__() self.inference_params = inference_params self.diff_trainer_params = diff_trainer_params self.module_loader = module_loader self.on_start_once_called = False self._setup_methods = [] module_loader( trainer=self, diff_trainer_params=diff_trainer_params) # ------ IMPLEMENTATION HOOKS ------- def post_init(self, batch): ''' Is called after LightningDataModule and LightningModule is created, but before any training/validation/prediction. First possible access to the 'trainer' object (e.g. to get 'device'). ''' def generate_output(self, batch, batch_idx, inference_params: InferenceParams): ''' Is called during validation to generate for each batch an output. Return the meta information about produced result (where result were stored). This is used for the metric evaluation. ''' # ------- HELPER FUNCTIONS ------- def _reset_random_generator(self): ''' Reset the random generator to the same seed across all workers. The generator is used only for inference. ''' if not hasattr(self, "random_generator"): self.random_generator = torch.Generator(device=self.device) # set seed according to 'seed_everything' in config seed = int(os.environ.get("PL_GLOBAL_SEED", 42)) else: seed = self.random_generator.initial_seed() self.random_generator.manual_seed(seed) # ----- PREDICT HOOKS ------ def on_predict_start(self): self.on_start() def predict_step(self, batch, batch_idx): self.on_inference_step(batch=batch, batch_idx=batch_idx) def on_predict_epoch_start(self): self.on_inference_epoch_start() # ----- CUSTOM HOOKS ----- # Global Hooks (Called by Training, Validation and Prediction) # abstract method def _on_start_once(self): ''' Will be called only once by on_start. Thus, it will be called by the first call of train,validation or prediction. ''' if self.on_start_once_called: return else: self.on_start_once_called = True self.post_init() def on_start(self): ''' Called at the beginning of training, validation and prediction. ''' self._on_start_once() # Inference Hooks (Called by Validation and Prediction) # ----- Inference Hooks (called by 'validation' and 'predict') ------ def on_inference_epoch_start(self): # reset seed at every inference self._reset_random_generator() def on_inference_step(self, batch, batch_idx): if self.inference_params.reset_seed_per_generation: self._reset_random_generator() self.generate_output( batch=batch, inference_params=self.inference_params, batch_idx=batch_idx)