# model-editing/losses.py
import torch
import torch.nn.functional as F
from metrics import es_sentiment
from utils import gather_log_probs, mask_hf_labels, masked_mean


def balanced_bce(log_probs, labels, eps=torch.finfo(torch.float32).eps):
    """Class-balanced BCE: expects log-probabilities of the positive class (e.g. from
    F.logsigmoid) and averages the positive and negative terms separately."""
    assert labels.max() <= 1
    assert labels.min() >= 0

    pos_losses = -log_probs[labels == 1]
    neg_probs = 1 - log_probs.exp()
    neg_probs[neg_probs == 0] += eps  # for numerical stability
    neg_losses = -neg_probs.log()[labels == 0]
    pos_loss = pos_losses.mean() if pos_losses.numel() > 0 else 0
    neg_loss = neg_losses.mean() if neg_losses.numel() > 0 else 0

    return pos_loss + neg_loss
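
# Illustrative usage sketch (not from the original file); the variable names below are
# made up, and balanced_bce is fed logsigmoid outputs rather than raw logits:
#
#     logits = torch.randn(8)
#     labels = torch.randint(0, 2, (8,))
#     loss = balanced_bce(F.logsigmoid(logits), labels)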


def kl_loc_loss(pre, post, mask=None):
    """Locality loss: KL(pre || post) between pre-edit and post-edit predictions,
    penalizing the edited model for drifting on unrelated inputs."""
    pre = pre.to(torch.float32)
    post = post.to(torch.float32)

    sequence = pre.dim() == 3
    pre_ = pre.view(-1, pre.shape[-1])
    post_ = post.view(pre_.shape)
    assert pre_.shape[0] == post_.shape[0]

    if not sequence:
        if pre_.shape[-1] == 1:  # No masking needed for binary classification
            # KL between two Bernoulli distributions, written with logsigmoid for stability
            return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + (
                (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))
            ).mean()
    else:  # We have sequences of predictions; masking needed
        if pre_.shape[-1] > 1:
            assert mask is not None
            mask_ = mask.view(pre_.shape[0])
            kl = (pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))).sum(-1)
            return (kl * mask_).sum() / mask_.sum()

    raise NotImplementedError
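
# Illustrative usage sketch (not from the original file); names and shapes below are
# hypothetical. `mask` marks the token positions that should count toward the loss:
#
#     pre = torch.randn(4, 10, 100)               # pre-edit logits  [batch, seq, vocab]
#     post = pre + 0.01 * torch.randn_like(pre)   # post-edit logits
#     mask = torch.ones(4, 10)                    # e.g. an attention-style validity mask
#     loc = kl_loc_loss(pre, post, mask=mask)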


def binary_log_probs(pred, targ, should_reduce=True):
    """Accuracy and log-probability metrics for single-logit (binary) predictions."""
    assert targ.max() <= 1
    assert targ.min() >= 0

    # Flip the sign of the logit wherever the target is 0, so logsigmoid yields the
    # log-probability of the correct label at every position.
    neg_mask = torch.ones_like(pred)
    neg_mask[targ == 0] *= -1
    pred = pred * neg_mask
    log_probs = F.logsigmoid(pred)
    acc = (log_probs.exp() > 0.5).float()
    if should_reduce:
        acc = acc.mean()

    return {
        "acc": acc,
        "log_prob": log_probs.mean(),
        "prob": log_probs.exp().mean(),
        "nll": -log_probs.mean(),
        "n_tokens": log_probs.shape[0],
    }
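
# Illustrative usage sketch (not from the original file); variable names are hypothetical:
#
#     logits = torch.randn(16, 1)             # one logit per example
#     labels = torch.randint(0, 2, (16,))
#     stats = binary_log_probs(logits, labels)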


def multiclass_log_probs(
    pred,
    raw_targets,
    shift=True,
    eps=torch.finfo(torch.float32).eps,
    should_reduce=True,
    **kwargs,
):
    """Token-level accuracy and log-probability metrics for sequence (multiclass) predictions."""
    NULL_TOKEN = 0  # a placeholder used for masked target locations

    pred = pred.clone()
    mask, targ = mask_hf_labels(raw_targets)
    if shift and pred.dim() == 3:  # Dealing with sequences
        pred = pred[:, :-1]  # Remove last prediction in sequence
        targ = targ[:, 1:]  # Shift to align predictions and targets
        mask = mask[:, 1:]  # Keep the validity mask aligned with the shifted targets

    unmasked_log_probs = gather_log_probs(pred, targ)

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = (pred_ids == targ).all(-1)  # We want to get the whole sequence right
    acc = correct.float()
    if should_reduce:
        acc = acc.mean()

    if "inner_sent" in kwargs:
        # Only use outer samples with the same sentiment as the inner sample
        same_sent_mask = torch.tensor(
            [i == o for i, o in zip(kwargs["inner_sent"], kwargs["outer_sent"])],
            device=pred.device,
        )
        good_mask = mask * same_sent_mask.unsqueeze(-1)
        bad_mask = mask * (~same_sent_mask.unsqueeze(-1))

        good_log_prob = masked_mean(unmasked_log_probs, good_mask)
        bad_log_prob = masked_mean((1 - unmasked_log_probs.exp() + eps).log(), bad_mask)

        n_tokens = good_mask.float().sum()
        avg_log_prob = good_log_prob

        if kwargs["unlikelihood"]:
            nll = -good_log_prob - bad_log_prob
        else:
            nll = -good_log_prob
    else:
        n_tokens = mask.float().sum()
        avg_log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens
        nll = -avg_log_prob

    info_dict = {
        "acc": acc,
        "log_prob": avg_log_prob,
        "prob": avg_log_prob.exp(),
        "n_tokens": n_tokens,
        "nll": nll,
    }

    if "inner_sent" in kwargs:
        info_dict.update(
            es_sentiment(
                kwargs["pre_edit_logits"],
                kwargs["post_edit_logits"],
                raw_targets,
                same_sent_mask,
            )
        )

    return info_dict
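
# Illustrative usage sketch (not from the original file). Assumes mask_hf_labels maps the
# HF convention of -100 "ignore" labels to a validity mask plus cleaned targets, and that
# gather_log_probs picks out each target token's log-probability; names are hypothetical:
#
#     logits = torch.randn(2, 6, 100)           # [batch, seq, vocab]
#     labels = torch.randint(0, 100, (2, 6))
#     labels[:, :3] = -100                      # masked-out (prompt) positions
#     stats = multiclass_log_probs(logits, labels)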


def masked_log_probs(pred, targ, shift=True, **kwargs):
    """Dispatch to binary or multiclass metrics based on the size of the final logit dimension."""
    pred = pred.to(torch.float32)

    if not (pred.dim() == 2 or pred.dim() == 3):
        raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}")

    if pred.shape[-1] == 1:
        should_reduce = kwargs.get("should_reduce", True)
        return binary_log_probs(pred, targ, should_reduce=should_reduce)
    else:
        return multiclass_log_probs(pred, targ, shift=shift, **kwargs)
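
# Illustrative usage sketch (not from the original file): the final logit dimension picks
# the branch. Shapes and names below are hypothetical:
#
#     masked_log_probs(torch.randn(2, 6, 100), torch.randint(0, 100, (2, 6)))  # multiclass
#     masked_log_probs(torch.randn(2, 6, 1), torch.randint(0, 2, (2, 6)))      # binary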


def test_masked_log_probs():
    print()
    N = 10000
    pred = torch.randn(10, 15, N)
    targ = torch.randint(0, N, (10, 15))
    # Build a "perfect" prediction tensor (true_pred) and one that is only right on half
    # of the batch (half_pred), to sanity-check the accuracy and NLL numbers.
    true_pred = pred.clone()
    true_pred.scatter_(2, targ.unsqueeze(-1), 5)
    true_pred = true_pred.roll(-1, 1)
    half_pred = true_pred.clone()
    mask = torch.arange(10) % 2 == 0
    half_pred[mask] = pred[mask]

    pred_ = pred.clone()
    true_pred_ = true_pred.clone()
    half_pred_ = half_pred.clone()
    targ_ = targ.clone()

    print(masked_log_probs(pred, targ))
    print(masked_log_probs(true_pred, targ))
    print(masked_log_probs(half_pred, targ))

    # The loss computation should not modify its inputs in place.
    assert (pred == pred_).all()
    assert (targ == targ_).all()
    assert (half_pred == half_pred_).all()
    assert (true_pred == true_pred_).all()

    pred = torch.randn(1000, 15, 1)
    targ = torch.randint(0, 2, (1000, 15))

    print(masked_log_probs(pred, targ))


if __name__ == "__main__":
    torch.manual_seed(0)
    test_masked_log_probs()