|
import math |
|
|
|
def get_topk_results(predictions, scores, targets, k, all_items=None, invalid_score=-1000):
    """Rank each target's ``k`` generated candidates and flag which one matches the target.

    Args:
        predictions: Flat sequence of generated strings, ``k`` consecutive
            entries per target (entries ``b*k`` .. ``b*k + k - 1`` belong to
            ``targets[b]``). Only the text after the last ``"Response:"`` is
            used; surrounding whitespace is stripped and every space removed.
        scores: Flat sequence of scores aligned with ``predictions``; higher
            scores rank first. It is copied, never modified.
        targets: Ground-truth strings, one per group of ``k`` predictions.
        k: Number of candidates per target.
        all_items: Optional collection of valid item strings. Candidates not
            found in it receive ``invalid_score`` so they sink in the ranking.
        invalid_score: Score assigned to candidates outside ``all_items``
            (defaults to -1000, the historical penalty).

    Returns:
        One list per target, in descending-score order (ties keep generation
        order because ``sorted`` is stable), holding 1 where the candidate
        equals the target and 0 otherwise.
    """
    # Keep only the answer part of each generation and normalise it.
    cleaned = [p.split("Response:")[-1].strip().replace(" ", "") for p in predictions]

    # Copy so the caller's scores (list / array / tensor) are left untouched.
    scores = list(scores)
    if all_items is not None:
        # Hash-based membership instead of a linear scan for every candidate.
        if isinstance(all_items, (set, frozenset, dict)):
            valid_items = all_items
        else:
            valid_items = set(all_items)
        for i, seq in enumerate(cleaned):
            if seq not in valid_items:
                scores[i] = invalid_score

    results = []
    for b, target_item in enumerate(targets):
        start = b * k
        ranked = sorted(
            zip(cleaned[start:start + k], scores[start:start + k]),
            key=lambda pair: pair[1],
            reverse=True,
        )
        results.append([1 if seq == target_item else 0 for seq, _ in ranked])

    return results
|
|
|
def get_metrics_results(topk_results, metrics):
    """Compute every requested ranking metric over ``topk_results``.

    Args:
        topk_results: Per-target 0/1 relevance lists in rank order.
        metrics: Metric names of the form ``"hit@K"`` or ``"ndcg@K"``; the
            prefix is matched case-insensitively and ``K`` is parsed as int.

    Returns:
        dict mapping each metric name (exactly as given) to the value returned
        by ``hit_k`` / ``ndcg_k``, which sum over targets rather than average.

    Raises:
        NotImplementedError: if a metric name has an unsupported prefix.
    """
    res = {}
    for m in metrics:
        name = m.lower()
        if name.startswith("hit"):
            res[m] = hit_k(topk_results, int(m.split("@")[1]))
        elif name.startswith("ndcg"):
            res[m] = ndcg_k(topk_results, int(m.split("@")[1]))
        else:
            # Name the offending metric so a typo in a config is easy to spot.
            raise NotImplementedError(
                "Unsupported metric: {!r} (expected 'hit@K' or 'ndcg@K')".format(m)
            )
    return res
|
|
|
|
|
def ndcg_k(topk_results, k):
    """Return the NDCG@k contributions of all rows, summed (not averaged).

    Each row is a rank-ordered relevance list. Its score is
    ``sum(rel / log2(rank + 2))`` over its first ``k`` entries. No ideal-DCG
    division is performed; with a single relevant entry per row the ideal DCG
    is 1, so this equals NDCG. Divide by the number of rows to get the mean.
    """
    total = 0.0
    for row in topk_results:
        row_gain = 0.0
        # Explicit accumulation keeps the exact float summation order.
        for rank, rel in enumerate(row[:k]):
            row_gain += rel / math.log(rank + 2, 2)
        total += row_gain
    return total
|
|
|
|
|
def hit_k(topk_results, k):
    """Count the rows whose first ``k`` entries sum to a positive value.

    The count is returned as a float and is not averaged over rows.
    """
    hit_rows = sum(1 for row in topk_results if sum(row[:k]) > 0)
    return float(hit_rows)
|
|
|
|