File size: 422 Bytes
6d1b6c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import torch
def metrics(similarity_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute top-1 retrieval accuracy in both directions of a similarity matrix.

    The ground-truth match for index ``i`` is taken to be index ``i`` on the
    other axis, i.e. correct pairs lie on the diagonal. Going by the local
    naming, rows are presumably images and columns captions (confirm against
    the caller).

    Args:
        similarity_matrix: Square 2-D tensor of pairwise scores, where a
            larger value means a better match.

    Returns:
        A tuple ``(img_acc, cap_acc)`` of scalar float tensors on the same
        device as the input:
        ``img_acc`` is the fraction of rows whose highest-scoring column is
        the diagonal one; ``cap_acc`` is the fraction of columns whose
        highest-scoring row is the diagonal one.

    Raises:
        ValueError: If ``similarity_matrix`` is not a square 2-D tensor.
    """
    # Without this check an (N, 1) or (1, N) input would silently broadcast in
    # the equality comparisons below and yield meaningless accuracies.
    if similarity_matrix.ndim != 2 or similarity_matrix.shape[0] != similarity_matrix.shape[1]:
        raise ValueError(
            "similarity_matrix must be a square 2-D tensor, "
            f"got shape {tuple(similarity_matrix.shape)}"
        )

    # Build the target indices directly on the input's device instead of
    # allocating on the CPU and copying afterwards.
    y = torch.arange(similarity_matrix.shape[0], device=similarity_matrix.device)

    # Best column for every row, and best row for every column.
    img2cap_match_idx = similarity_matrix.argmax(dim=1)
    cap2img_match_idx = similarity_matrix.argmax(dim=0)

    img_acc = (img2cap_match_idx == y).float().mean()
    cap_acc = (cap2img_match_idx == y).float().mean()
    return img_acc, cap_acc
|