Polos-Demo / validate /pascal50s.py
yuwd's picture
init
03f6091
raw
history blame
4.15 kB
# from https://github.com/jmhessel/clipscore/issues/4
import torch
import random
import scipy
import os
from tqdm import tqdm
class Pascal50sDataset(torch.utils.data.Dataset):
    """Pairwise caption-preference dataset read from the PASCAL-50S ``.mat`` files.

    Every item pairs one image with two candidate captions (``a`` and ``b``),
    the list of reference strings gathered from the consensus triplets, a
    category tag (``'HC'``, ``'HI'``, ``'HM'`` or ``'MM'``) and a binary label:
    ``0`` when caption ``a`` collected more votes, ``1`` when caption ``b`` did.

    Args:
        root: Directory containing ``pyCIDErConsensus/pair_pascal.mat`` and
            ``pyCIDErConsensus/consensus_pascal.mat``.
        media_size: Accepted for interface compatibility; no code in this
            class reads it.
        voc_path: Directory whose ``JPEGImages`` sub-directory holds the images.
    """

    # Number of consecutive consensus triplets that belong to one caption pair.
    TRIPLETS_PER_PAIR = 48

    def __init__(self,
                 root: str = "data/Pascal-50s/",
                 media_size: int = 224,
                 voc_path: str = "data/VOC2010/"):
        super().__init__()
        self.voc_path = voc_path
        # Seed before read_score(), which draws random numbers to break ties.
        self.fix_seed()
        self.read_data(root)
        self.read_score(root)
        # Maps the integer category stored in the .mat file to its tag.
        self.idx2cat = {1: 'HC', 2: 'HI', 3: 'HM', 4: 'MM'}

    @staticmethod
    def loadmat(path):
        """Load a MATLAB ``.mat`` file and return its variable dictionary."""
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.io`` is available as an attribute.
        import scipy.io
        return scipy.io.loadmat(path)

    def fix_seed(self, seed=42):
        """Seed both ``torch`` and the global ``random`` module with *seed*."""
        torch.manual_seed(seed)
        random.seed(seed)

    def read_data(self, root):
        """Load the caption pairs and categories, then sanity-check them.

        Sets ``self.data`` (``new_input[0]``) and ``self.categories``
        (``category[0]``) from ``pair_pascal.mat``.

        Raises:
            AssertionError: if the category counts derived from ``new_data``
                are not 1000 each, or disagree with ``self.categories``.
        """
        mat = self.loadmat(
            os.path.join(root, "pyCIDErConsensus/pair_pascal.mat"))
        self.data = mat["new_input"][0]
        self.categories = mat["category"][0]

        # Sanity check: rebuild the four category masks from ``new_data`` and
        # compare them with the stored category ids.
        c = torch.Tensor(mat["new_data"])
        row_sum = c.sum(dim=-1)
        below_six = (c < 6).sum(dim=-1)
        hc = (row_sum == 12).int()
        hi = (row_sum == 13).int()
        hm = (below_six == 1).int()
        mm = (below_six == 2).int()
        for mask in (hc, hi, hm, mm):
            assert 1000 == mask.sum()
        assert (hc + hi + hm + mm).sum() == self.categories.shape[0]
        # Category ids are 1..4 in the order HC, HI, HM, MM.
        chk = (torch.Tensor(self.categories) - hc - hi * 2 - hm * 3 - mm * 4)
        assert 0 == chk.abs().sum(), chk

    def read_score(self, root):
        """Tally the consensus votes and derive one label per caption pair.

        Reads ``consensus_pascal.mat`` and, for pair ``i``, consumes triplets
        ``i * 48 .. (i + 1) * 48 - 1``.  Each triplet holds four fields: a
        reference string, two candidate strings and a choice flag; the flag
        ``1`` votes for the first candidate, anything else for the second.
        Sets ``self.labels`` and ``self.references``.

        A candidate caption from ``self.data`` that never appears among the
        voted captions simply counts as zero votes.

        Raises:
            AssertionError: if a pair's triplets vote for zero or more than
                two distinct captions.
        """
        mat = self.loadmat(
            os.path.join(root, "pyCIDErConsensus/consensus_pascal.mat"))
        data = mat["triplets"][0]
        per_pair = self.TRIPLETS_PER_PAIR

        self.labels = []
        self.references = []
        for i in range(len(self)):
            votes = {}
            refs = []
            for j in range(i * per_pair, (i + 1) * per_pair):
                a, b, c, d = [x[0][0] for x in data[j]]
                key = b[0].strip() if 1 == d else c[0].strip()
                refs.append(a[0].strip())
                votes[key] = votes.get(key, 0) + 1
            assert 2 >= len(votes), votes
            assert len(votes) > 0

            # ``dict.get`` with a default never raises KeyError, so no
            # exception handling is needed around these lookups.
            vote_a = votes.get(self.data[i][1][0].strip(), 0)
            vote_b = votes.get(self.data[i][2][0].strip(), 0)

            # Ties are broken randomly: when vote_a == vote_b the offset in
            # (-0.5, 0.5) picks 0 or 1 with probability 0.5 each; any real
            # vote difference (>= 1) is never overturned by it.
            label = 0 if vote_a > vote_b + random.random() - .5 else 1
            self.labels.append(label)
            self.references.append(refs)

    def __len__(self):
        """Return the number of caption pairs."""
        return len(self.data)

    def get_image_path(self, filename: str):
        """Return ``<voc_path>/JPEGImages/<filename>``."""
        path = os.path.join(self.voc_path, "JPEGImages")
        return os.path.join(path, filename)

    def __getitem__(self, idx: int):
        """Return ``(img_path, a, b, references, category_str, label)`` for *idx*.

        ``a`` and ``b`` are the whitespace-stripped candidate captions.
        """
        vid, a, b = [x[0] for x in self.data[idx]]
        label = self.labels[idx]
        img_path = self.get_image_path(vid)
        a = a.strip()
        b = b.strip()
        references = self.references[idx]
        category = self.categories[idx]
        category_str = self.idx2cat[category]
        return img_path, a, b, references, category_str, label
def sanity_check(detail=False):
    """Build the dataset from ``pascal/`` and verify every sample's image exists.

    Args:
        detail: When true, print each sample's fields while iterating.

    Raises:
        AssertionError: if the first sample is ``None`` or any image path
            does not exist on disk.
    """
    def dprint(*args, **kwargs):
        # Print only in detailed mode; otherwise stay silent.
        if detail:
            print(*args, **kwargs)

    dataset = Pascal50sDataset(root="pascal/", voc_path="pascal/VOCdevkit/VOC2010")
    first = dataset[0]
    assert first is not None

    banner = "=" * 20
    for index, sample in enumerate(tqdm(dataset)):
        dprint(banner)
        dprint("sample:", index)
        dprint(banner)
        img_path, a, b, references, category, label = sample
        assert os.path.exists(img_path)
        fields = (
            ("img_path:", img_path),
            ("a:", a),
            ("b:", b),
            ("references:", references),
            ("category:", category),
            ("label:", label),
        )
        for tag, value in fields:
            dprint(tag, value)
if __name__ == "__main__":
    # Script entry point: run the dataset self-check without per-sample
    # printing (``detail`` defaults to False).
    sanity_check()