tiny_clip / src /metrics.py
sachin's picture
Initial training code
6d1b6c6
raw
history blame
No virus
422 Bytes
import torch
def metrics(similarity_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute top-1 retrieval accuracy in both directions.

    Following the ``img2cap`` / ``cap2img`` naming, rows are treated as
    images and columns as captions. The correct pair for index ``i`` is the
    diagonal entry ``[i, i]``.

    - An image (row) counts as a hit when its highest-scoring column is ``i``.
    - A caption (column) counts as a hit when its highest-scoring row is ``i``.

    Args:
        similarity_matrix: Square 2-D tensor of shape ``(N, N)``.

    Returns:
        ``(img_acc, cap_acc)``: 0-dim float tensors on the same device as
        the input, giving image->caption and caption->image accuracy.
        When several entries tie for the maximum, ``argmax`` picks the
        first one.

    Raises:
        ValueError: If ``similarity_matrix`` is not a square 2-D tensor.
            A non-square input would otherwise broadcast against the
            diagonal targets and silently yield meaningless accuracies.
    """
    if (
        similarity_matrix.ndim != 2
        or similarity_matrix.shape[0] != similarity_matrix.shape[1]
    ):
        raise ValueError(
            "similarity_matrix must be a square 2-D tensor, "
            f"got shape {tuple(similarity_matrix.shape)}"
        )
    # Target for row/column i is index i (the diagonal). Build it directly on
    # the input's device rather than allocating on CPU and copying over.
    y = torch.arange(similarity_matrix.shape[0], device=similarity_matrix.device)
    img2cap_match_idx = similarity_matrix.argmax(dim=1)  # best column per row
    cap2img_match_idx = similarity_matrix.argmax(dim=0)  # best row per column
    img_acc = (img2cap_match_idx == y).float().mean()
    cap_acc = (cap2img_match_idx == y).float().mean()
    return img_acc, cap_acc