from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch import nn
from torch.utils.data import DataLoader
from transformers import Trainer

from src.models.dataset import EntitySentimentDataset


def compute_class_weights(examples: list[dict], n_classes: int) -> torch.Tensor:
    """Return inverse-frequency class weights: total / (n_classes * count).

    Classes missing from `examples` fall back to a count of 1 so their
    weight stays finite.
    """
    counts = Counter(e["label"] for e in examples)
    total = sum(counts.values())
    weights = [total / (n_classes * counts.get(i, 1)) for i in range(n_classes)]
    return torch.tensor(weights, dtype=torch.float)
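
# Quick sanity check on hypothetical toy data: for a 6/3/1 split over three
# classes,
#   compute_class_weights([{"label": l} for l in [0] * 6 + [1] * 3 + [2]], 3)
# returns tensor([0.5556, 1.1111, 3.3333]) -- rarer classes get larger weights.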


def focal_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    weight: torch.Tensor,
    gamma: float = 2.0,
) -> torch.Tensor:
    """Class-weighted focal loss: mean of (1 - p_t)^gamma * weighted CE.

    `weight` acts as the per-class alpha term via the cross-entropy call;
    gamma > 0 down-weights examples the model already classifies confidently.
    """
    ce = F.cross_entropy(logits, labels, weight=weight, reduction="none")
    probs = F.softmax(logits, dim=-1)
    # p_t: the probability the model assigns to each example's true class.
    pt = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
    return ((1 - pt) ** gamma * ce).mean()
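
# Hedged check: with gamma=0 the modulating factor (1 - pt)**0 is 1, so
#   focal_loss(logits, labels, torch.ones(n_classes), gamma=0.0)
# equals plain F.cross_entropy(logits, labels); with non-unit weights the two
# differ only in normalisation (batch mean here vs. weight-sum mean in torch).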


class WeightedLossTrainer(Trainer):
    """Trainer that swaps the default loss for weighted CE or focal loss."""

    def __init__(
        self,
        *args,
        class_weights: torch.Tensor,
        loss_fn: str = "cross_entropy",
        focal_gamma: float = 2.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self.loss_fn = loss_fn
        self.focal_gamma = focal_gamma

    def compute_loss(self, model, inputs, return_outputs: bool = False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # Move the weights to whichever device the model placed its logits on.
        w = self.class_weights.to(outputs.logits.device)
        if self.loss_fn == "focal":
            loss = focal_loss(outputs.logits, labels, weight=w, gamma=self.focal_gamma)
        else:
            loss = nn.CrossEntropyLoss(weight=w)(outputs.logits, labels)
        return (loss, outputs) if return_outputs else loss


def reconstruct_triplets(
    yes_probs: np.ndarray, bin_labels: np.ndarray
) -> tuple[list[int], list[int]]:
    """Collapse consecutive (neg, neu, pos) yes/no rows into 3-class items.

    Each entity contributes three consecutive binary examples: the predicted
    class is the position with the highest "yes" probability, and the gold
    class is the position whose binary label is 1. Any trailing incomplete
    group is dropped by the loop bound.
    """
    preds3, labels3 = [], []
    for i in range(0, len(yes_probs) - 2, 3):
        preds3.append(int(np.argmax(yes_probs[i: i + 3])))
        labels3.append(int(np.argmax(bin_labels[i: i + 3])))
    return preds3, labels3
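
# Illustrative example with hypothetical values (two entities, three rows each):
#   yes_probs  = np.array([0.1, 0.7, 0.2, 0.2, 0.3, 0.5])
#   bin_labels = np.array([0, 1, 0, 0, 0, 1])
#   reconstruct_triplets(yes_probs, bin_labels)  # -> ([1, 2], [1, 2])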


def make_compute_metrics(mode: str):
    """Build a Trainer-compatible metrics function for the given mode."""
    if mode in ("marker", "qa_m"):
        # 3-class heads: report macro F1 plus per-class F1.
        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            preds = np.argmax(logits, axis=-1)
            macro_f1 = float(f1_score(labels, preds, average="macro"))
            per_class = f1_score(labels, preds, average=None, labels=[0, 1, 2])
            return {
                "macro_f1": macro_f1,
                "f1_negative": float(per_class[0]),
                "f1_neutral": float(per_class[1]),
                "f1_positive": float(per_class[2]),
            }
    else:
        # Binary yes/no head: report binary metrics, then reconstruct the
        # underlying 3-class task from consecutive triplets for macro F1.
        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            preds = np.argmax(logits, axis=-1)
            bin_acc = float((preds == labels).mean())
            bin_f1 = float(f1_score(labels, preds, average="binary", pos_label=1))

            yes_probs = F.softmax(
                torch.tensor(logits, dtype=torch.float), dim=-1
            )[:, 1].numpy()

            preds3, labels3 = reconstruct_triplets(yes_probs, labels)

            macro_f1 = (
                float(f1_score(labels3, preds3, average="macro"))
                if preds3
                else 0.0
            )
            return {
                "macro_f1": macro_f1,
                "bin_accuracy": bin_acc,
                "bin_f1_yes": bin_f1,
            }

    return compute_metrics


def evaluate_qa_b_test(
    model,
    tokenizer,
    test_exs: list[dict],
    max_len: int,
    batch_size: int,
    device: torch.device,
) -> tuple[float, list[int], list[int]]:
    """Score the binary QA model on the test set at the 3-class level.

    Collects the "yes" probability for every binary example, regroups them
    into per-entity triplets, and returns (macro_f1, preds3, labels3).
    """
    ds = EntitySentimentDataset(test_exs, tokenizer, max_len)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    all_yes_probs, all_bin_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            all_yes_probs.extend(
                F.softmax(logits, dim=-1)[:, 1].cpu().tolist()
            )
            all_bin_labels.extend(batch["labels"].tolist())

    preds3, labels3 = reconstruct_triplets(
        np.array(all_yes_probs), np.array(all_bin_labels)
    )

    macro_f1 = float(f1_score(labels3, preds3, average="macro"))
    return macro_f1, preds3, labels3
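

# Example wiring (a sketch, not project code: the model name, label count, and
# TrainingArguments below are assumptions for illustration).
#
#   from transformers import (
#       AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments,
#   )
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = AutoModelForSequenceClassification.from_pretrained(
#       "bert-base-uncased", num_labels=3,
#   )
#   train_exs = [...]  # list[dict] with a "label" key, as used above
#   trainer = WeightedLossTrainer(
#       model=model,
#       args=TrainingArguments(output_dir="out", per_device_train_batch_size=16),
#       train_dataset=EntitySentimentDataset(train_exs, tokenizer, max_len=128),
#       compute_metrics=make_compute_metrics("marker"),
#       class_weights=compute_class_weights(train_exs, n_classes=3),
#       loss_fn="focal",
#   )
#   trainer.train()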