import torch
import torch.nn.functional as F

from metrics import es_sentiment
from utils import gather_log_probs, mask_hf_labels, masked_mean


def balanced_bce(log_probs, labels, eps=torch.finfo(torch.float32).eps):
    # Binary cross-entropy with the positive and negative terms averaged
    # separately, so an imbalanced batch doesn't drown out the rarer class.
    assert labels.max() <= 1
    assert labels.min() >= 0

    pos_losses = -log_probs[labels == 1]
    neg_probs = 1 - log_probs.exp()
    neg_probs[neg_probs == 0] += eps  # for numerical stability
    neg_losses = -neg_probs.log()[labels == 0]
    pos_loss = pos_losses.mean() if pos_losses.numel() > 0 else 0
    neg_loss = neg_losses.mean() if neg_losses.numel() > 0 else 0
    return pos_loss + neg_loss


def kl_loc_loss(pre, post, mask=None):
    # Locality loss: KL(pre || post) between the pre- and post-edit output
    # distributions, penalizing drift on inputs the edit should not affect.
    pre = pre.to(torch.float32)
    post = post.to(torch.float32)

    sequence = pre.dim() == 3
    pre_ = pre.view(-1, pre.shape[-1])
    post_ = post.view(pre_.shape)
    assert pre_.shape[0] == post_.shape[0]

    if not sequence:
        if pre_.shape[-1] == 1:  # No masking needed for binary classification
            return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + (
                (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))
            ).mean()
    else:  # We have sequences of predictions; masking needed
        if pre_.shape[-1] > 1:
            assert mask is not None
            mask_ = mask.view(pre_.shape[0])
            kl = (pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))).sum(-1)
            return (kl * mask_).sum() / mask_.sum()

    raise NotImplementedError


def binary_log_probs(pred, targ, should_reduce=True):
    assert targ.max() <= 1
    assert targ.min() >= 0
    # Flip the sign of the logits for negative targets so that
    # logsigmoid(pred) is the log-probability of the correct label.
    neg_mask = torch.ones_like(pred)
    neg_mask[targ == 0] *= -1
    pred = pred * neg_mask
    log_probs = F.logsigmoid(pred)
    acc = (log_probs.exp() > 0.5).float()
    if should_reduce:
        acc = acc.mean()
    return {
        "acc": acc,
        "log_prob": log_probs.mean(),
        "prob": log_probs.exp().mean(),
        "nll": -log_probs.mean(),
        "n_tokens": log_probs.shape[0],
    }


def multiclass_log_probs(
    pred,
    raw_targets,
    shift=True,
    eps=torch.finfo(torch.float32).eps,
    should_reduce=True,
    **kwargs,
):
    NULL_TOKEN = 0  # a placeholder used for masked target locations

    pred = pred.clone()
    mask, targ = mask_hf_labels(raw_targets)
    if shift and pred.dim() == 3:  # Dealing with sequences
        pred = pred[:, :-1]  # Remove last prediction in sequence
        targ = targ[:, 1:]  # Shift to align predictions and targets
        mask = mask[:, 1:]  # Keep the mask aligned with the shifted targets

    unmasked_log_probs = gather_log_probs(pred, targ)

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = (pred_ids == targ).all(-1)  # We want to get the whole sequence right
    acc = correct.float()
    if should_reduce:
        acc = acc.mean()

    if "inner_sent" in kwargs:
        # Only use outer samples with the same sentiment as the inner sample
        same_sent_mask = torch.tensor(
            [i == o for i, o in zip(kwargs["inner_sent"], kwargs["outer_sent"])],
            device=pred.device,
        )
        good_mask = mask * same_sent_mask.unsqueeze(-1)
        bad_mask = mask * (~same_sent_mask.unsqueeze(-1))

        good_log_prob = masked_mean(unmasked_log_probs, good_mask)
        bad_log_prob = masked_mean((1 - unmasked_log_probs.exp() + eps).log(), bad_mask)

        n_tokens = good_mask.float().sum()
        avg_log_prob = good_log_prob

        if kwargs["unlikelihood"]:
            # Unlikelihood term: also push down the probability of tokens
            # under the mismatched-sentiment mask.
            nll = -good_log_prob - bad_log_prob
        else:
            nll = -good_log_prob
    else:
        n_tokens = mask.float().sum()
        avg_log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens
        nll = -avg_log_prob

    info_dict = {
        "acc": acc,
        "log_prob": avg_log_prob,
        "prob": avg_log_prob.exp(),
        "n_tokens": n_tokens,
        "nll": nll,
    }

    if "inner_sent" in kwargs:
        info_dict.update(
            es_sentiment(
                kwargs["pre_edit_logits"],
                kwargs["post_edit_logits"],
                raw_targets,
                same_sent_mask,
            )
        )

    return info_dict
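
# Reference sketches of the helpers imported from `utils` above, for readers
# who don't have utils.py at hand. These are assumptions inferred from how the
# helpers are used in this file, not the canonical implementations; the
# imported versions are the ones actually called.
def _gather_log_probs_sketch(logits, labels):
    # Log-probability the model assigns to each target token: log-softmax over
    # the vocabulary, gathered at the label indices. (B, T, V) -> (B, T).
    return logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)


def _mask_hf_labels_sketch(labels, null_token=0):
    # HuggingFace-style labels mark ignored positions with -100. Returns a
    # validity mask plus a copy with ignored positions replaced by a real
    # token id so downstream gathers stay in range.
    mask = labels != -100
    return mask, labels.masked_fill(~mask, null_token)


def _masked_mean_sketch(values, mask):
    # Mean of `values` over positions where `mask` is True.
    return (values * mask.float()).sum() / mask.float().sum()
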
def masked_log_probs(pred, targ, shift=True, **kwargs):
    # Dispatch on the output dimension: a final dimension of 1 means binary
    # classification; anything larger is treated as multiclass/sequence output.
    pred = pred.to(torch.float32)

    if not (pred.dim() == 2 or pred.dim() == 3):
        raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}")

    if pred.shape[-1] == 1:
        should_reduce = kwargs.get("should_reduce", True)
        return binary_log_probs(pred, targ, should_reduce=should_reduce)
    else:
        return multiclass_log_probs(pred, targ, shift=shift, **kwargs)


def test_masked_log_probs():
    print()
    N = 10000
    pred = torch.randn(10, 15, N)
    targ = torch.randint(0, N, (10, 15))

    # Build logits that place a high score on every (shifted) target token,
    # plus a version that is only correct on half of the batch.
    true_pred = pred.clone()
    true_pred.scatter_(2, targ.unsqueeze(-1), 5)
    true_pred = true_pred.roll(-1, 1)

    half_pred = true_pred.clone()
    mask = torch.arange(10) % 2 == 0
    half_pred[mask] = pred[mask]

    pred_ = pred.clone()
    true_pred_ = true_pred.clone()
    half_pred_ = half_pred.clone()
    targ_ = targ.clone()

    print(masked_log_probs(pred, targ))
    print(masked_log_probs(true_pred, targ))
    print(masked_log_probs(half_pred, targ))

    # The loss computation should not modify its inputs in place.
    assert (pred == pred_).all()
    assert (targ == targ_).all()
    assert (half_pred == half_pred_).all()
    assert (true_pred == true_pred_).all()

    # Binary case: final dimension of 1 dispatches to binary_log_probs.
    pred = torch.randn(1000, 15, 1)
    targ = torch.randint(0, 2, (1000, 15))

    print(masked_log_probs(pred, targ))


if __name__ == "__main__":
    torch.manual_seed(0)
    test_masked_log_probs()
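
# Example usage in an editing training loop (a minimal sketch; the model
# calls, variable names, and loss weighting below are illustrative
# assumptions, not part of this module):
#
#   edit_logits = edited_model(edit_tokens)["logits"]        # (batch, seq, vocab)
#   edit_stats = masked_log_probs(edit_logits, edit_labels)  # -100 marks ignored positions
#   edit_loss = edit_stats["nll"]
#
#   pre = base_model(loc_tokens)["logits"]                   # logits before editing
#   post = edited_model(loc_tokens)["logits"]                # logits after editing
#   loc_loss = kl_loc_loss(pre, post, mask=loc_attention_mask)
#
#   total_loss = edit_loss + c_loc * loc_loss                # c_loc: locality weight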