waidhoferj's picture
fixed metrics and weights
1c22425
raw
history blame
2.73 kB
import torch.nn as nn
import torch
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
class LabelWeightedBCELoss(nn.Module):
"""
Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
Allows for the weighing of each probability distribution wrt loss.
"""
def __init__(self, label_weights: torch.Tensor, reduction="mean"):
super().__init__()
self.label_weights = label_weights
match reduction:
case "mean":
self.reduction = torch.mean
case "sum":
self.reduction = torch.sum
def _log(self, x: torch.Tensor) -> torch.Tensor:
return torch.clamp_min(torch.log(x), -100)
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
losses = -self.label_weights * (
target * self._log(input) + (1 - target) * self._log(1 - input)
)
return self.reduction(losses)
# TODO: Code a onehot
def calculate_metrics(
pred, target, threshold=0.5, prefix="", multi_label=True
) -> dict[str, torch.Tensor]:
target = target.detach().cpu().numpy()
pred = pred.detach().cpu()
pred = nn.functional.softmax(pred, dim=1)
pred = pred.numpy()
params = {
"y_true": target if multi_label else target.argmax(1),
"y_pred": np.array(pred > threshold, dtype=float)
if multi_label
else pred.argmax(1),
"zero_division": 0,
"average": "macro",
}
metrics = {
"precision": precision_score(**params),
"recall": recall_score(**params),
"f1": f1_score(**params),
"accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
}
return {
prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
}
class EarlyStopping:
def __init__(self, patience=0):
self.patience = patience
self.last_measure = np.inf
self.consecutive_increase = 0
def step(self, val) -> bool:
if self.last_measure <= val:
self.consecutive_increase += 1
else:
self.consecutive_increase = 0
self.last_measure = val
return self.patience < self.consecutive_increase
def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
id2label = {str(i): label for i, label in enumerate(labels)}
label2id = {label: str(i) for i, label in enumerate(labels)}
return id2label, label2id
def compute_hf_metrics(eval_pred):
predictions = np.argmax(eval_pred.predictions, axis=1)
return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)