import logging
from typing import Dict, List, Optional

import lightning as pl
import torch
from lightning.pytorch.trainer.states import RunningStage
from sklearn.metrics import label_ranking_average_precision_score

from relik.common.log import get_console_logger, get_logger
from relik.retriever.callbacks.base import DEFAULT_STAGES, NLPTemplateCallback

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class RecallAtKEvaluationCallback(NLPTemplateCallback):
    """
    Computes the recall at k for the predictions. Recall at k is computed as the number of
    gold labels found in the top k predictions divided by the total number of gold labels.

    Args:
        k (`int`, `optional`, defaults to `100`):
            The number of predictions to consider.
        prefix (`str`, `optional`):
            The prefix to add to the metrics.
        verbose (`bool`, `optional`, defaults to `False`):
            Whether to log the metrics.
        prog_bar (`bool`, `optional`, defaults to `True`):
            Whether to log the metrics to the progress bar.
    """

    def __init__(
        self,
        k: int = 100,
        prefix: Optional[str] = None,
        verbose: bool = False,
        prog_bar: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.k = k
        self.prefix = prefix
        self.verbose = verbose
        self.prog_bar = prog_bar

    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Dict,
        *args,
        **kwargs,
    ) -> dict:
        """
        Computes the recall at k for the predictions.

        Args:
            trainer (:obj:`lightning.trainer.trainer.Trainer`):
                The trainer object.
            pl_module (:obj:`lightning.core.lightning.LightningModule`):
                The lightning module.
            predictions (:obj:`Dict`):
                The predictions, keyed by dataloader index.

        Returns:
            :obj:`Dict`: The computed metrics.
        """
        if self.verbose:
            logger.info(f"Computing recall@{self.k}")

        # metrics to return
        metrics = {}

        stage = trainer.state.stage
        if stage not in DEFAULT_STAGES:
            raise ValueError(
                f"Stage {stage} not supported, only `validate` and `test` are supported."
            )

        for dataloader_idx, samples in predictions.items():
            hits, total = 0, 0
            for sample in samples:
                # cut the predictions to the first k elements; avoid reusing the
                # name `predictions`, which would shadow the dict being iterated
                top_k = sample["predictions"][: self.k]
                # recall@k: gold labels retrieved in the top k over all gold labels
                hits += len(set(top_k) & set(sample["gold"]))
                total += len(set(sample["gold"]))
            # recall at k for this dataloader
            recall_at_k = hits / total
            metrics[f"recall@{self.k}_{dataloader_idx}"] = recall_at_k
        # mean recall at k across dataloaders
        metrics[f"recall@{self.k}"] = sum(metrics.values()) / len(metrics)

        prefix = self.prefix or stage.value
        metrics = {f"{prefix}_{k}": v for k, v in metrics.items()}
        pl_module.log_dict(
            metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar
        )

        if self.verbose:
            logger.info(
                f"Recall@{self.k} on {prefix}: {metrics[f'{prefix}_recall@{self.k}']}"
            )
        return metrics
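

# A minimal sketch (not part of the original module) of the recall@k
# arithmetic that RecallAtKEvaluationCallback performs per dataloader: gold
# labels found in the top-k predictions over all gold labels. The toy samples
# below are made up and only mirror the {"predictions": ..., "gold": ...}
# structure the callback expects.
def _recall_at_k_example(k: int = 2) -> float:
    samples = [
        {"predictions": ["a", "b", "c", "d"], "gold": ["a", "d"]},
        {"predictions": ["x", "y", "z"], "gold": ["y", "w"]},
    ]
    hits, total = 0, 0
    for sample in samples:
        top_k = sample["predictions"][:k]
        hits += len(set(top_k) & set(sample["gold"]))  # 1 + 1 = 2
        total += len(set(sample["gold"]))  # 2 + 2 = 4
    return hits / total  # 2 / 4 = 0.5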


class AvgRankingEvaluationCallback(NLPTemplateCallback):
    """
    Computes the average ranking of the gold labels in the predictions, i.e. the
    mean of the 1-based positions at which the gold labels appear in the top k
    predictions.

    Args:
        k (`int`):
            The number of predictions to consider.
        prefix (`str`, `optional`):
            The prefix to add to the metrics.
        stages (`List[str]`, `optional`):
            The stages to compute the metrics on. Defaults to `["validate", "test"]`.
        verbose (`bool`, `optional`, defaults to `True`):
            Whether to log the metrics.
    """

    def __init__(
        self,
        k: int,
        prefix: Optional[str] = None,
        stages: Optional[List[str]] = None,
        verbose: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.k = k
        self.prefix = prefix
        self.verbose = verbose
        self.stages = (
            [RunningStage(stage) for stage in stages] if stages else DEFAULT_STAGES
        )

    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Dict,
        *args,
        **kwargs,
    ) -> dict:
        """
        Computes the average ranking of the gold labels in the predictions.

        Args:
            trainer (:obj:`lightning.trainer.trainer.Trainer`):
                The trainer object.
            pl_module (:obj:`lightning.core.lightning.LightningModule`):
                The lightning module.
            predictions (:obj:`Dict`):
                The predictions, keyed by dataloader index.

        Returns:
            :obj:`Dict`: The computed metrics.
        """
        if not predictions:
            logger.warning("No predictions to compute the AVG Ranking metrics.")
            return {}

        if self.verbose:
            logger.info(f"Computing AVG Ranking@{self.k}")

        # metrics to return
        metrics = {}

        stage = trainer.state.stage
        if stage not in self.stages:
            raise ValueError(
                f"Stage `{stage}` not supported, only {self.stages} are supported."
            )

        for dataloader_idx, samples in predictions.items():
            rankings = []
            for sample in samples:
                window_candidates = sample["predictions"][: self.k]
                window_labels = sample["gold"]
                for wl in window_labels:
                    if wl in window_candidates:
                        # 1-based rank of the gold label among the candidates
                        rankings.append(window_candidates.index(wl) + 1)
            avg_ranking = sum(rankings) / len(rankings) if len(rankings) > 0 else 0
            metrics[f"avg_ranking@{self.k}_{dataloader_idx}"] = avg_ranking
        # mean of the per-dataloader averages
        if len(metrics) == 0:
            metrics[f"avg_ranking@{self.k}"] = 0
        else:
            metrics[f"avg_ranking@{self.k}"] = sum(metrics.values()) / len(metrics)

        prefix = self.prefix or stage.value
        metrics = {
            f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32)
            for k, v in metrics.items()
        }
        pl_module.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False)

        if self.verbose:
            logger.info(
                f"AVG Ranking@{self.k} on {prefix}: {metrics[f'{prefix}_avg_ranking@{self.k}']}"
            )
        return metrics
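

# A minimal sketch (not part of the original module) of the average-ranking
# arithmetic in AvgRankingEvaluationCallback: for each gold label found in the
# top-k candidates, record its 1-based position, then average. Toy data only.
def _avg_ranking_example(k: int = 3) -> float:
    samples = [
        {"predictions": ["a", "b", "c"], "gold": ["b"]},  # rank 2
        {"predictions": ["x", "y", "z"], "gold": ["x", "z"]},  # ranks 1 and 3
    ]
    rankings = []
    for sample in samples:
        candidates = sample["predictions"][:k]
        for gold in sample["gold"]:
            if gold in candidates:
                rankings.append(candidates.index(gold) + 1)
    return sum(rankings) / len(rankings)  # (2 + 1 + 3) / 3 = 2.0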


class LRAPEvaluationCallback(NLPTemplateCallback):
    """
    Computes the Label Ranking Average Precision (LRAP) of the predictions with
    `sklearn.metrics.label_ranking_average_precision_score`.

    Args:
        k (`int`, `optional`, defaults to `100`):
            The number of scores to consider.
        prefix (`str`, `optional`):
            The prefix to add to the metrics.
        verbose (`bool`, `optional`, defaults to `False`):
            Whether to log the metrics.
        prog_bar (`bool`, `optional`, defaults to `True`):
            Whether to log the metrics to the progress bar.
    """

    def __init__(
        self,
        k: int = 100,
        prefix: Optional[str] = None,
        verbose: bool = False,
        prog_bar: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.k = k
        self.prefix = prefix
        self.verbose = verbose
        self.prog_bar = prog_bar

    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Dict,
        *args,
        **kwargs,
    ) -> dict:
        if self.verbose:
            logger.info(f"Computing LRAP@{self.k}")

        # metrics to return
        metrics = {}

        stage = trainer.state.stage
        if stage not in DEFAULT_STAGES:
            raise ValueError(
                f"Stage {stage} not supported, only `validate` and `test` are supported."
            )

        for dataloader_idx, samples in predictions.items():
            scores = [sample["scores"][: self.k] for sample in samples]
            golds = [sample["gold"] for sample in samples]
            # LRAP for this dataloader
            lrap = label_ranking_average_precision_score(golds, scores)
            metrics[f"lrap@{self.k}_{dataloader_idx}"] = lrap
        # mean LRAP across dataloaders
        metrics[f"lrap@{self.k}"] = sum(metrics.values()) / len(metrics)

        prefix = self.prefix or stage.value
        metrics = {
            f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32)
            for k, v in metrics.items()
        }
        pl_module.log_dict(
            metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar
        )

        if self.verbose:
            logger.info(
                f"LRAP@{self.k} on {prefix}: {metrics[f'{prefix}_lrap@{self.k}']}"
            )
        return metrics
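

# A minimal sketch (not part of the original module) of how
# sklearn.metrics.label_ranking_average_precision_score is called: `y_true` is
# a binary indicator matrix of gold labels and `y_score` holds the predicted
# scores, one row per sample; LRAPEvaluationCallback presumably feeds
# sample["gold"] and sample["scores"] in this shape. Toy values, made up.
def _lrap_example() -> float:
    import numpy as np

    y_true = np.array([[1, 0, 0], [0, 0, 1]])
    y_score = np.array([[0.75, 0.5, 1.0], [1.0, 0.2, 0.1]])
    # (1/2 + 1/3) / 2 = 0.4166...
    return label_ranking_average_precision_score(y_true, y_score)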