# model-editing / metrics.py
# Charles Lin
# Add algorithms from efk codebase
# e56055d
import torch
from utils import gather_log_probs, mask_hf_labels, masked_mean
def es_sentiment(pre_logits, post_logits, raw_targets, same_sent_mask, NULL_TOKEN=0):
    """Compute the edit-success score for a sentiment edit.

    The score combines two factors:

    * ``z_sent`` -- how strongly the edited model prefers the "positive" rows
      (selected by ``same_sent_mask``) over the "negative" rows (its
      complement), squashed through a sigmoid.
    * ``z_topic`` -- how much the positive rows' mean log-likelihood changed
      from the pre-edit to the post-edit model, exponentiated and capped at 1
      so that only a drop is penalized.

    Everything is computed under ``torch.no_grad()``.

    Args:
        pre_logits: Logits from the model before the edit.
        post_logits: Logits from the model after the edit.
        raw_targets: Labels passed to ``mask_hf_labels``, which yields the
            token mask and the targets whose log-probabilities are gathered.
        same_sent_mask: Mask selecting the positive rows. It must be boolean
            because it is inverted with ``~`` to get the negatives. It is
            broadcast against the token mask via ``unsqueeze(-1)``
            (presumably (batch,) vs. (batch, seq) -- TODO confirm).
        NULL_TOKEN: Unused; kept for interface compatibility.

    Returns:
        dict with keys:
            ``acc_sent``: ``z_sent * z_topic``.
            ``z_sent``: ``sigmoid(mean_pos_post - mean_neg_post)``.
            ``z_topic``: ``z_topic_raw`` clamped to at most 1.
            ``z_topic_raw``: ``exp(mean_pos_post - mean_pos_pre)``.
            ``correct_probs`` / ``wrong_probs``: post-edit masked means for
            the positive / negative rows. They come from ``gather_log_probs``,
            so despite the key names these are presumably *log*-probabilities.
    """
    with torch.no_grad():
        mask, targ = mask_hf_labels(raw_targets)
        # Restrict the row-level positive/negative selection to label tokens.
        pos_mask = same_sent_mask.unsqueeze(-1) * mask
        neg_mask = (~same_sent_mask).unsqueeze(-1) * mask

        # Compute log likelihoods of pos/neg samples
        pre_edit_token_log_probs = gather_log_probs(pre_logits, targ)
        post_edit_token_log_probs = gather_log_probs(post_logits, targ)

        mean_pos_pre = masked_mean(pre_edit_token_log_probs, pos_mask)
        mean_pos_post = masked_mean(post_edit_token_log_probs, pos_mask)
        mean_neg_post = masked_mean(post_edit_token_log_probs, neg_mask)

        # Post-edit preference for positives over negatives, mapped into (0, 1).
        z_sent = (mean_pos_post - mean_neg_post).sigmoid()
        # > 1 when the edit made the positive targets more likely, < 1 when less.
        z_topic_raw = (mean_pos_post - mean_pos_pre).exp()
        # Cap at 1: a drop in positive likelihood is penalized, a gain is not
        # rewarded. ``clamp`` rather than builtin ``min`` keeps the result a
        # tensor in every case (``min`` could return the int 1) and also works
        # if the means are non-scalar tensors.
        z_topic = z_topic_raw.clamp(max=1.0)
        es_sent = z_sent * z_topic

        return {
            "acc_sent": es_sent,
            "z_sent": z_sent,
            "z_topic": z_topic,
            "z_topic_raw": z_topic_raw,
            "correct_probs": mean_pos_post,
            "wrong_probs": mean_neg_post,
        }
# DEPRECATED
def sent_success(pre_edit_probs, post_edit_probs, pos_mask, eps=torch.finfo(torch.float32).eps, batch_size=20):
    """DEPRECATED sentiment edit-success metric -- always raises.

    Calling this function raises ``AssertionError("No longer used")``. The
    original implementation is kept below the ``raise`` for reference only and
    is unreachable. (Presumably superseded by ``es_sentiment`` in this module.)

    Raises:
        AssertionError: always.
    """
    # Explicit ``raise`` instead of ``assert False``: assert statements are
    # stripped under ``python -O``, which would let the stale code below run.
    raise AssertionError("No longer used")

    # ---- Legacy implementation: unreachable, kept for reference only. ----
    # content_score = post_edit_probs[pos_mask].prod() ** (1/pos_mask.sum()) / (pre_edit_probs[pos_mask]. + eps)
    post_pos_avg = post_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum())
    pre_pos_avg = pre_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum())
    content_score = post_pos_avg / (pre_pos_avg + eps)
    z_content = min(1., content_score)

    # compute z_sent through a weighting objective
    # normalized_probs = post_edit_probs / (post_edit_probs.sum() + eps)
    # balancing_factor = 0.5 * ((~pos_mask).float().sum() / pos_mask.float().sum() + 1)
    # z_sent_weight = balancing_factor * normalized_probs.dot(pos_mask.float())
    post_neg_avg = post_edit_probs[~pos_mask].prod() ** (1 / (~pos_mask).sum())
    neg_over_pos = post_neg_avg / (eps + post_pos_avg)
    z_sent_weight = 1 / (1 + neg_over_pos)

    # compute z_sent through a ranking objective
    batch_mask = pos_mask.view(-1, batch_size).long()
    sort_idxs = post_edit_probs.view(-1, batch_size).sort(-1, descending=True).indices
    ranked_mask = batch_mask.gather(1, sort_idxs)
    true_mask = batch_mask.sort(-1, descending=True).values
    z_sent_rank = (ranked_mask == true_mask).float().mean()

    # compute the final success scores
    weight_success = (z_content * z_sent_weight) ** 0.5
    rank_success = (z_content * z_sent_rank) ** 0.5

    correct_probs = post_edit_probs[pos_mask].mean()
    wrong_probs = post_edit_probs[~pos_mask].mean()

    return {
        "acc_weight": weight_success,
        "acc_rank": rank_success,
        "rank_score": z_sent_rank,
        "weight_score": z_sent_weight,
        "content_score": content_score,
        "post_edit_probs": post_edit_probs.sum(),
        "pre_edit_probs": pre_edit_probs.sum(),
        "correct_probs": correct_probs,
        "wrong_probs": wrong_probs
    }
# def sent_retain(pre_logits, post_logits, sent_mask, batch_size=20, eps=torch.finfo(torch.float32).eps):
# pre_log_probs = pre_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)
# post_log_probs = post_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)
# pre_batch = pre_probs.view(-1, batch_size)
# post_batch = post_probs.view(-1, batch_size)
# mask_batch = sent_mask.view(-1, batch_size)
# stats = []
# for pre, post, mask in zip(pre_batch, post_batch, mask_batch):
# avg_pre = pre.prod() ** (1 / pre.numel())
# avg_post = post.prod() ** (1 / post.numel())
# z_avg = min(avg_pre / avg_post, avg_post / avg_pre)
# post_neg_avg = post[~mask].prod() ** (1 / (~mask).sum())
# post_pos_avg = post[mask].prod() ** (1 / mask.sum())
# pre_neg_avg = pre[~mask].prod() ** (1 / (~mask).sum())
# pre_pos_avg = pre[mask].prod() ** (1 / mask.sum())
# post_neg_over_pos = post_neg_avg / (eps + post_pos_avg)
# pre_neg_over_pos = pre_neg_avg / (eps + pre_pos_avg)
# z_post = 1 / (1 + post_neg_over_pos)
# z_pre = 1 / (1 + pre_neg_over_pos)
# z_sent = min(z_post / z_pre, z_pre / z_post)
# stats.append((z_avg * z_sent) ** 0.5)
# return sum(stats) / len(stats)
# For zsRE and F-NLI
def retain_rate(pre_logits, post_logits, mask=None):
    """Return the fraction of examples whose prediction is unchanged by an edit.

    Two input layouts are supported:

    * 1-D logits (a trailing singleton dim is squeezed first): treated as
      binary classification, predicting ``logit > 0`` per example.
    * 3-D logits: per-position predictions are ``argmax`` over the last dim.
      A row counts as retained only if every position selected by ``mask``
      (or every position, when ``mask`` is None) has the same prediction
      before and after the edit.

    Args:
        pre_logits: Logits before the edit.
        post_logits: Logits after the edit; must match ``pre_logits`` in shape
            after squeezing.
        mask: Optional per-position mask for the 3-D case, whose first
            dimension must match the logits'. In the 1-D case it is only
            shape-checked.

    Returns:
        float: the mean retention rate over the leading (example) dimension.

    Raises:
        AssertionError: if the shapes are inconsistent.
        NotImplementedError: for logits that are neither 1-D nor 3-D.
    """
    if pre_logits.shape[-1] == 1:
        pre_logits = pre_logits.squeeze(-1)
    if post_logits.shape[-1] == 1:
        post_logits = post_logits.squeeze(-1)

    assert pre_logits.shape == post_logits.shape
    # ``mask`` is optional (and unused for binary classification), so only
    # validate it when one was actually supplied.
    if mask is not None:
        assert pre_logits.shape[0] == mask.shape[0]

    if pre_logits.dim() == 1:
        # binary classification
        pre_preds = pre_logits > 0
        post_preds = post_logits > 0
        retain = (pre_preds == post_preds).float().mean()
    elif pre_logits.dim() == 3:
        # sequence modeling
        pre_preds = pre_logits.argmax(-1)
        post_preds = post_logits.argmax(-1)
        match = pre_preds == post_preds
        if mask is None:
            # No mask: every position must match.
            retain = match.all(-1).float().mean()
        else:
            # A row is retained when all of its masked positions match.
            match = match * mask
            retain = (match.sum(-1) == mask.sum(-1)).float().mean()
    else:
        raise NotImplementedError
    return retain.item()