from typing import Optional, Union

import hydra
import lightning as pl
import torch
from omegaconf import DictConfig

from relik.retriever.common.model_inputs import ModelInputs


class GoldenRetrieverPLModule(pl.LightningModule):
    def __init__(
        self,
        model: Union[torch.nn.Module, DictConfig],
        optimizer: Union[torch.optim.Optimizer, DictConfig],
        lr_scheduler: Optional[
            Union[torch.optim.lr_scheduler.LRScheduler, DictConfig]
        ] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        # Save the optimizer/scheduler configs in the checkpoint, but skip the
        # model itself, whose weights are already stored in the state dict.
        self.save_hyperparameters(ignore=["model"])
        # The model can be passed either as an already-built module or as a
        # Hydra config that is instantiated here.
        if isinstance(model, DictConfig):
            self.model = hydra.utils.instantiate(model)
        else:
            self.model = model

        self.optimizer_config = optimizer
        self.lr_scheduler_config = lr_scheduler

    def forward(self, **kwargs) -> dict:
        """
        Forward pass of the underlying model.

        `training_step`, `validation_step` and `test_step` should call this
        method to compute the output predictions and the loss.

        Returns:
            output_dict: forward output containing the predictions
                (output logits, etc.) and the loss, if any.
        """
        return self.model(**kwargs)

    def training_step(self, batch: ModelInputs, batch_idx: int) -> torch.Tensor:
        forward_output = self.forward(**batch, return_loss=True)
        self.log(
            "loss",
            forward_output["loss"],
            batch_size=batch["questions"]["input_ids"].size(0),
            prog_bar=True,
        )
        return forward_output["loss"]

    def validation_step(self, batch: ModelInputs, batch_idx: int) -> None:
        forward_output = self.forward(**batch, return_loss=True)
        self.log(
            "val_loss",
            forward_output["loss"],
            batch_size=batch["questions"]["input_ids"].size(0),
        )

    def test_step(self, batch: ModelInputs, batch_idx: int) -> None:
        forward_output = self.forward(**batch, return_loss=True)
        self.log(
            "test_loss",
            forward_output["loss"],
            batch_size=batch["questions"]["input_ids"].size(0),
        )

    def configure_optimizers(self):
        if isinstance(self.optimizer_config, DictConfig):
            # Split the parameters into three groups:
            #   1. the extra "layer_norm_layer" projection, which gets its own
            #      learning rate,
            #   2. regular weights, which get weight decay,
            #   3. biases and LayerNorm parameters, which get no weight decay.
            param_optimizer = list(self.named_parameters())
            no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in param_optimizer if "layer_norm_layer" in n
                    ],
                    "weight_decay": self.hparams.optimizer.weight_decay,
                    "lr": 1e-4,
                },
                {
                    "params": [
                        p
                        for n, p in param_optimizer
                        if all(nd not in n for nd in no_decay)
                        and "layer_norm_layer" not in n
                    ],
                    "weight_decay": self.hparams.optimizer.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in param_optimizer
                        if "layer_norm_layer" not in n
                        and any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer = hydra.utils.instantiate(
                self.optimizer_config,
                params=optimizer_grouped_parameters,
                _convert_="partial",
            )
        else:
            optimizer = self.optimizer_config

        if self.lr_scheduler_config is None:
            return optimizer

        if isinstance(self.lr_scheduler_config, DictConfig):
            lr_scheduler = hydra.utils.instantiate(
                self.lr_scheduler_config, optimizer=optimizer
            )
        else:
            lr_scheduler = self.lr_scheduler_config

        # Step the scheduler after every optimization step.
        lr_scheduler_config = {
            "scheduler": lr_scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler_config]
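

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, self-contained
# example of how GoldenRetrieverPLModule can be exercised outside a Trainer,
# assuming a toy stand-in model. `DummyRetriever` below is hypothetical and
# only mimics the interface the module expects (a forward that returns a dict
# with a "loss" key when `return_loss=True`, and a batch exposing
# questions["input_ids"]); the real retriever model lives in relik.retriever.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf

    class DummyRetriever(torch.nn.Module):
        """Toy bi-encoder with the interface GoldenRetrieverPLModule expects."""

        def __init__(self, vocab_size: int = 100, dim: int = 16) -> None:
            super().__init__()
            self.embeddings = torch.nn.Embedding(vocab_size, dim)

        def forward(self, questions, passages, return_loss=False, **kwargs):
            # Mean-pooled embeddings for questions and passages.
            q = self.embeddings(questions["input_ids"]).mean(dim=1)
            p = self.embeddings(passages["input_ids"]).mean(dim=1)
            logits = q @ p.T
            output = {"logits": logits}
            if return_loss:
                # In-batch negatives: the i-th passage is the positive
                # for the i-th question.
                labels = torch.arange(logits.size(0))
                output["loss"] = torch.nn.functional.cross_entropy(logits, labels)
            return output

    # Optimizer passed as an OmegaConf/Hydra config, one of the two code paths
    # supported by configure_optimizers.
    optimizer_cfg = OmegaConf.create(
        {"_target_": "torch.optim.AdamW", "lr": 1e-5, "weight_decay": 0.01}
    )
    module = GoldenRetrieverPLModule(model=DummyRetriever(), optimizer=optimizer_cfg)

    # A fake batch with the nested structure used by the *_step methods.
    batch = {
        "questions": {"input_ids": torch.randint(0, 100, (4, 8))},
        "passages": {"input_ids": torch.randint(0, 100, (4, 8))},
    }
    loss = module.training_step(batch, batch_idx=0)
    print("toy training loss:", loss.item())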