# sv-task / src/models/distillbert.py
# lamossta — commit 51620d3: "models and inference classes"
from collections import Counter
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch import nn
from torch.utils.data import DataLoader
from transformers import Trainer
from src.models.dataset import EntitySentimentDataset
def compute_class_weights(examples: list[dict], n_classes: int) -> torch.Tensor:
    """Return inverse-frequency ("balanced") class weights as a float tensor.

    The weight for class ``c`` is ``total / (n_classes * count_c)``, where
    ``total`` is the number of examples. A class that never occurs is
    treated as if it occurred once. This avoids division by zero, but it
    gives that class the largest possible weight.

    Args:
        examples: Dicts that each carry an integer ``"label"`` key.
        n_classes: Number of classes; weights are produced for
            ``0 .. n_classes - 1``.

    Returns:
        1-D ``torch.float`` tensor of length ``n_classes``.
    """
    label_counts = Counter(example["label"] for example in examples)
    n_labels = sum(label_counts.values())
    per_class = []
    for cls in range(n_classes):
        # Fall back to a count of 1 for unseen classes to avoid dividing by zero.
        seen = label_counts.get(cls, 1)
        per_class.append(n_labels / (n_classes * seen))
    return torch.tensor(per_class, dtype=torch.float)
def focal_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    weight: torch.Tensor,
    gamma: float = 2.0,
) -> torch.Tensor:
    """Class-weighted focal loss, averaged over the batch.

    Each sample's class-weighted cross-entropy is scaled by
    ``(1 - p_t) ** gamma``, where ``p_t`` is the softmax probability of the
    sample's true class. Confidently-correct samples therefore contribute
    less. With ``gamma == 0`` this equals the per-sample weighted
    cross-entropy reduced with a plain ``mean()``. That is not the
    weight-normalised mean that ``nn.CrossEntropyLoss(weight=...)`` uses.

    Args:
        logits: Unnormalised class scores. Softmax is taken over the last
            dim and the true-class probability is gathered along dim 1, so
            this expects 2-D ``(batch, num_classes)`` input.
        labels: Integer class index for each row of ``logits``.
        weight: Per-class weights forwarded to ``F.cross_entropy``.
        gamma: Focusing exponent.

    Returns:
        Scalar loss tensor.
    """
    true_prob = F.softmax(logits, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
    modulator = (1 - true_prob) ** gamma
    weighted_ce = F.cross_entropy(logits, labels, weight=weight, reduction="none")
    return (modulator * weighted_ce).mean()
class WeightedLossTrainer(Trainer):
    """``Trainer`` whose training loss applies fixed per-class weights.

    ``loss_fn="focal"`` selects ``focal_loss`` with exponent ``focal_gamma``.
    Any other value falls back to class-weighted cross-entropy.
    """

    def __init__(
        self,
        *args,
        class_weights: torch.Tensor,
        loss_fn: str = "cross_entropy",
        focal_gamma: float = 2.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Kept on the caller's device; moved to the logits' device in compute_loss.
        self.class_weights = class_weights
        self.loss_fn = loss_fn
        self.focal_gamma = focal_gamma

    def compute_loss(self, model, inputs, return_outputs: bool = False, **kwargs):
        """Compute the class-weighted loss for one batch.

        ``"labels"`` is popped from ``inputs``, which mutates the dict, so
        the model is called without labels. Extra keyword arguments are
        accepted and ignored.

        Returns:
            ``(loss, outputs)`` if ``return_outputs`` is true, else ``loss``.
        """
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        scores = outputs.logits
        weights = self.class_weights.to(scores.device)
        if self.loss_fn != "focal":
            loss = F.cross_entropy(scores, targets, weight=weights)
        else:
            loss = focal_loss(scores, targets, weight=weights, gamma=self.focal_gamma)
        if return_outputs:
            return loss, outputs
        return loss
def reconstruct_triplets(
    yes_probs: np.ndarray, bin_labels: np.ndarray
) -> tuple[list[int], list[int]]:
    """Collapse consecutive (neg, neu, pos) binary rows back into 3-class outputs.

    Rows are read in consecutive groups of three. For each group:

    - the prediction is the position (0, 1 or 2) of the largest
      ``yes_probs`` entry;
    - the gold label is the position of the largest ``bin_labels`` entry.

    ``np.argmax`` returns the first position on ties. A trailing group of
    fewer than three rows is ignored.

    Returns:
        ``(preds3, labels3)``: lists of Python ints, one entry per group.
    """
    group_starts = range(0, len(yes_probs) - 2, 3)
    preds3 = [int(np.argmax(yes_probs[start:start + 3])) for start in group_starts]
    labels3 = [int(np.argmax(bin_labels[start:start + 3])) for start in group_starts]
    return preds3, labels3
def make_compute_metrics(mode: str):
    """Build a ``compute_metrics`` callback for ``Trainer`` evaluation.

    Args:
        mode: ``"marker"`` or ``"qa_m"`` for direct 3-class classification.
            Any other value is treated as the binary yes/no setup. There,
            each prediction row is scored as binary and rows are also
            regrouped into 3-class triplets with ``reconstruct_triplets``.

    Returns:
        A callable that unpacks ``eval_pred`` as ``(logits, labels)`` and
        returns a dict of plain-``float`` metrics.
    """
    if mode in ("marker", "qa_m"):

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            preds = np.argmax(logits, axis=-1)
            macro_f1 = f1_score(labels, preds, average="macro")
            # Explicit label list so each of the three classes always has a slot.
            per_class = f1_score(labels, preds, average=None, labels=[0, 1, 2])
            return {
                "macro_f1": float(macro_f1),
                "f1_negative": float(per_class[0]),
                "f1_neutral": float(per_class[1]),
                "f1_positive": float(per_class[2]),
            }

    else:

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            preds = np.argmax(logits, axis=-1)
            bin_acc = float((preds == labels).mean())
            bin_f1 = float(f1_score(labels, preds, average="binary", pos_label=1))
            # Softmax probability of class index 1 (the "yes" class) for each row.
            yes_probs = F.softmax(
                torch.tensor(logits, dtype=torch.float), dim=-1
            )[:, 1].numpy()
            # NOTE(review): assumes eval rows arrive as consecutive triplets in
            # dataset order (no shuffling) — confirm for distributed eval.
            preds3, labels3 = reconstruct_triplets(yes_probs, labels)
            # sklearn's signature is f1_score(y_true, y_pred); gold labels go first.
            macro_f1 = (
                float(f1_score(labels3, preds3, average="macro")) if preds3 else 0.0
            )
            return {
                "macro_f1": macro_f1,
                "bin_accuracy": bin_acc,
                "bin_f1_yes": bin_f1,
            }

    return compute_metrics
def evaluate_qa_b_test(
    model,
    tokenizer,
    test_exs: list[dict],
    max_len: int,
    batch_size: int,
    device: torch.device,
) -> tuple[float, list[int], list[int]]:
    """Run a binary (yes/no) model over a test set and score it as 3-class.

    Rows are batched in dataset order. The probability of class index 1 and
    the binary label are collected for each row. Consecutive triplets are
    then collapsed with ``reconstruct_triplets`` and scored with macro F1.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...)``
            and exposing ``.logits``. It is switched to eval mode and left
            in eval mode.
        tokenizer: Passed through to ``EntitySentimentDataset``.
        test_exs: Test examples passed to ``EntitySentimentDataset``.
        max_len: Maximum sequence length for the dataset.
        batch_size: Inference batch size.
        device: Device the input tensors are moved to.

    Returns:
        ``(macro_f1, preds3, labels3)``. ``macro_f1`` is ``0.0`` when no
        complete triplet is available, matching the guard in the binary
        ``compute_metrics`` callback.
    """
    ds = EntitySentimentDataset(test_exs, tokenizer, max_len)
    # shuffle=False keeps dataset order. Presumably the dataset emits each
    # example's (neg, neu, pos) rows adjacently — TODO confirm.
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
    all_yes_probs, all_bin_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            all_yes_probs.extend(F.softmax(logits, dim=-1)[:, 1].cpu().tolist())
            all_bin_labels.extend(batch["labels"].tolist())
    preds3, labels3 = reconstruct_triplets(
        np.array(all_yes_probs), np.array(all_bin_labels)
    )
    # Fewer than three rows yields no triplet. sklearn's behaviour on empty
    # input depends on its version, so return 0.0 explicitly.
    if not preds3:
        return 0.0, preds3, labels3
    macro_f1 = float(f1_score(labels3, preds3, average="macro"))
    return macro_f1, preds3, labels3