from typing import Any, Optional, Union

import hydra
import lightning as pl
import torch
from omegaconf import DictConfig

from relik.retriever.common.model_inputs import ModelInputs


class GoldenRetrieverPLModule(pl.LightningModule):
    def __init__(
        self,
        model: Union[torch.nn.Module, DictConfig],
        optimizer: Union[torch.optim.Optimizer, DictConfig],
        lr_scheduler: Optional[
            Union[torch.optim.lr_scheduler.LRScheduler, DictConfig]
        ] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        # Keep the model out of the checkpointed hyperparameters; its weights
        # are stored in the state dict anyway.
        self.save_hyperparameters(ignore=["model"])
        # `model`, `optimizer` and `lr_scheduler` can be passed either as
        # already-instantiated objects or as Hydra/OmegaConf configs.
        if isinstance(model, DictConfig):
            self.model = hydra.utils.instantiate(model)
        else:
            self.model = model
        self.optimizer_config = optimizer
        self.lr_scheduler_config = lr_scheduler

    def forward(self, **kwargs) -> dict:
        """
        Forward pass of the wrapped model.

        `training_step`, `validation_step` and `test_step` should call this
        method to compute the output predictions and, when requested, the loss.

        Returns:
            output_dict: forward output containing the predictions
                (output logits, etc.) and the loss, if any.
        """
        return self.model(**kwargs)

    def training_step(self, batch: ModelInputs, batch_idx: int) -> torch.Tensor:
        # Run the forward pass with the loss enabled and log it to the progress bar.
        forward_output = self.forward(**batch, return_loss=True)
        self.log(
            "loss",
            forward_output["loss"],
            batch_size=batch["questions"]["input_ids"].size(0),
            prog_bar=True,
        )
        return forward_output["loss"]

    def validation_step(self, batch: ModelInputs, batch_idx: int) -> None:
        forward_output = self.forward(**batch, return_loss=True)
        self.log(
            "val_loss",
            forward_output["loss"],
            batch_size=batch["questions"]["input_ids"].size(0),
        )

    def test_step(self, batch: ModelInputs, batch_idx: int) -> Any:
        forward_output = self.forward(**batch, return_loss=True)
        self.log(
            "test_loss",
            forward_output["loss"],
            batch_size=batch["questions"]["input_ids"].size(0),
        )

    def configure_optimizers(self):
        if isinstance(self.optimizer_config, DictConfig):
            # Build parameter groups: parameters of the dedicated layer-norm
            # layers ("layer_norm_layer") get their own learning rate, names
            # matching `no_decay` are exempt from weight decay, and everything
            # else uses the configured weight decay.
            param_optimizer = list(self.named_parameters())
            no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in param_optimizer if "layer_norm_layer" in n
                    ],
                    "weight_decay": self.hparams.optimizer.weight_decay,
                    "lr": 1e-4,
                },
                {
                    "params": [
                        p
                        for n, p in param_optimizer
                        if all(nd not in n for nd in no_decay)
                        and "layer_norm_layer" not in n
                    ],
                    "weight_decay": self.hparams.optimizer.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in param_optimizer
                        if "layer_norm_layer" not in n
                        and any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer = hydra.utils.instantiate(
                self.optimizer_config,
                # params=self.parameters(),
                params=optimizer_grouped_parameters,
                _convert_="partial",
            )
        else:
            optimizer = self.optimizer_config

        if self.lr_scheduler_config is None:
            return optimizer

        if isinstance(self.lr_scheduler_config, DictConfig):
            lr_scheduler = hydra.utils.instantiate(
                self.lr_scheduler_config, optimizer=optimizer
            )
        else:
            lr_scheduler = self.lr_scheduler_config

        # Step the scheduler at every optimization step rather than every epoch.
        lr_scheduler_config = {
            "scheduler": lr_scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler_config]
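

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the relik training
# pipeline): it wraps a toy linear encoder and an AdamW config just to show
# how the module builds its optimizer from a DictConfig via
# `configure_optimizers`. The toy encoder and the config values below are
# assumptions, not the real retriever setup.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    # Hypothetical stand-in for the actual GoldenRetriever encoder.
    toy_encoder = torch.nn.Linear(8, 8)
    optimizer_cfg = OmegaConf.create(
        {"_target_": "torch.optim.AdamW", "lr": 1e-5, "weight_decay": 0.01}
    )
    pl_module = GoldenRetrieverPLModule(model=toy_encoder, optimizer=optimizer_cfg)
    # With no lr_scheduler config, configure_optimizers returns the bare optimizer.
    print(pl_module.configure_optimizers())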