import torch | |
def metrics(similarity_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute top-1 image->caption and caption->image matching accuracy.

    Entry ``[i, j]`` is treated as the score between image ``i`` (row) and
    caption ``j`` (column); the correct pairing is the diagonal, i.e. image
    ``i`` belongs with caption ``i``.

    Args:
        similarity_matrix: Square ``(N, N)`` score tensor with ``N >= 1``.

    Returns:
        ``(img_acc, cap_acc)`` as 0-dim float tensors:
        ``img_acc`` is the fraction of rows whose argmax column is the row's
        own index; ``cap_acc`` is the fraction of columns whose argmax row is
        the column's own index. Ties are resolved by ``Tensor.argmax``.

    Raises:
        ValueError: If ``similarity_matrix`` is not a non-empty square 2-D
            tensor. Without this check, a non-square input either fails with
            an opaque broadcasting error or, when one side has length 1,
            silently broadcasts and yields a meaningless accuracy.
    """
    shape = tuple(similarity_matrix.shape)
    if len(shape) != 2 or shape[0] != shape[1]:
        raise ValueError(
            f"expected a square 2-D similarity matrix, got shape {shape}"
        )
    n = shape[0]
    if n == 0:
        raise ValueError("similarity_matrix must be non-empty")

    # Allocate the diagonal targets directly on the input's device rather
    # than creating them on the CPU and copying them over.
    targets = torch.arange(n, device=similarity_matrix.device)

    img2cap_match_idx = similarity_matrix.argmax(dim=1)  # best caption per image
    cap2img_match_idx = similarity_matrix.argmax(dim=0)  # best image per caption
    img_acc = (img2cap_match_idx == targets).float().mean()
    cap_acc = (cap2img_match_idx == targets).float().mean()
    return img_acc, cap_acc