File size: 422 Bytes
6d1b6c6
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch


def metrics(similarity_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute top-1 retrieval accuracy in both directions of a similarity matrix.

    The matching pair for index ``i`` is assumed to sit on the diagonal, i.e.
    entry ``[i, i]``. Following the naming used in this function, rows are
    treated as images and columns as captions.

    Args:
        similarity_matrix: Square 2-D tensor of pairwise similarity scores.

    Returns:
        A tuple ``(img_acc, cap_acc)`` of 0-dim float tensors:
        ``img_acc`` is the fraction of rows whose highest-scoring column is
        the diagonal one, and ``cap_acc`` is the fraction of columns whose
        highest-scoring row is the diagonal one.

    Raises:
        ValueError: If ``similarity_matrix`` is not a square 2-D tensor.
    """
    # Without this check, a 1xN or Nx1 input would silently broadcast in the
    # comparisons below and yield a meaningless accuracy instead of failing.
    if similarity_matrix.ndim != 2 or similarity_matrix.shape[0] != similarity_matrix.shape[1]:
        raise ValueError(
            "similarity_matrix must be a square 2-D tensor, "
            f"got shape {tuple(similarity_matrix.shape)}"
        )

    # Build the ground-truth indices directly on the input's device rather
    # than allocating on CPU and copying across.
    y = torch.arange(similarity_matrix.shape[0], device=similarity_matrix.device)

    img2cap_match_idx = similarity_matrix.argmax(dim=1)  # best column per row
    cap2img_match_idx = similarity_matrix.argmax(dim=0)  # best row per column

    img_acc = (img2cap_match_idx == y).float().mean()
    cap_acc = (cap2img_match_idx == y).float().mean()

    return img_acc, cap_acc