# from https://github.com/jmhessel/clipscore/issues/4
import os
import random

import scipy
import scipy.io
import torch
from tqdm import tqdm
class Pascal50sDataset(torch.utils.data.Dataset):
    """Pairwise caption-comparison dataset read from the Pascal-50S ``.mat`` files.

    Indexing yields ``(img_path, a, b, references, category_str, label)``:

    * ``img_path`` -- ``<voc_path>/JPEGImages/<file name stored in the pair>``.
    * ``a`` / ``b`` -- the two candidate sentences, whitespace-stripped.
    * ``references`` -- the stripped reference sentences collected from the
      ``TRIPLETS_PER_PAIR`` triplets that belong to the pair.
    * ``category_str`` -- one of ``'HC'``, ``'HI'``, ``'HM'``, ``'MM'``.
    * ``label`` -- ``0`` when ``a`` collected more votes than ``b``, ``1``
      when ``b`` collected more; ties are broken randomly.

    Constructing the dataset reseeds the global ``torch`` and ``random``
    generators (see :meth:`fix_seed`) so the tie-breaking is reproducible.
    """

    # Number of consecutive entries of ``triplets`` that belong to one pair.
    TRIPLETS_PER_PAIR = 48

    def __init__(self,
                 root: str = "data/Pascal-50s/",
                 media_size: int = 224,
                 voc_path: str = "data/VOC2010/"):
        """Load pairs, categories, references and vote labels from ``root``.

        Args:
            root: Directory containing ``pyCIDErConsensus/pair_pascal.mat``
                and ``pyCIDErConsensus/consensus_pascal.mat``.
            media_size: Accepted for interface compatibility; not used by
                this class.
            voc_path: Directory whose ``JPEGImages`` sub-directory holds the
                images referenced by the pairs.
        """
        super().__init__()
        self.voc_path = voc_path
        self.fix_seed()
        self.read_data(root)
        self.read_score(root)
        self.idx2cat = {1: 'HC', 2: 'HI', 3: 'HM', 4: 'MM'}

    @staticmethod
    def loadmat(path):
        """Return the contents of the ``.mat`` file at ``path``.

        Declared static because it takes no ``self``; without the decorator
        the ``self.loadmat(...)`` calls below would pass the instance as an
        unexpected extra argument and raise ``TypeError``.
        """
        return scipy.io.loadmat(path)

    def fix_seed(self, seed=42):
        """Seed the global ``torch`` and ``random`` generators with ``seed``."""
        torch.manual_seed(seed)
        random.seed(seed)

    def read_data(self, root):
        """Read the candidate pairs and their category ids from ``pair_pascal.mat``.

        Sets ``self.data`` (row 0 of ``new_input``) and ``self.categories``
        (row 0 of ``category``), then cross-checks the category ids against
        masks recomputed from ``new_data``.
        """
        mat = self.loadmat(
            os.path.join(root, "pyCIDErConsensus/pair_pascal.mat"))
        self.data = mat["new_input"][0]
        self.categories = mat["category"][0]

        # sanity check: rebuild one 0/1 mask per category from ``new_data``.
        c = torch.Tensor(mat["new_data"])
        hc = (c.sum(dim=-1) == 12).int()
        hi = (c.sum(dim=-1) == 13).int()
        hm = ((c < 6).sum(dim=-1) == 1).int()
        mm = ((c < 6).sum(dim=-1) == 2).int()
        # Each category is expected to contain exactly 1000 pairs ...
        assert 1000 == hc.sum()
        assert 1000 == hi.sum()
        assert 1000 == hm.sum()
        assert 1000 == mm.sum()
        # ... which together account for every entry of ``category`` ...
        assert (hc + hi + hm + mm).sum() == self.categories.shape[0]
        # ... and the stored id must equal the id implied by the masks
        # (hc -> 1, hi -> 2, hm -> 3, mm -> 4).
        chk = (torch.Tensor(self.categories) - hc - hi * 2 - hm * 3 - mm * 4)
        assert 0 == chk.abs().sum(), chk

    def read_score(self, root):
        """Read the triplets from ``consensus_pascal.mat`` and derive labels.

        For every pair, the ``TRIPLETS_PER_PAIR`` consecutive triplets are
        scanned: the first field is collected as a reference sentence and
        the chosen sentence (second field when the fourth field equals 1,
        third field otherwise) receives one vote.  Sets ``self.labels`` and
        ``self.references``, both indexed like ``self.data``.
        """
        mat = self.loadmat(
            os.path.join(root, "pyCIDErConsensus/consensus_pascal.mat"))
        triplets = mat["triplets"][0]
        per_pair = self.TRIPLETS_PER_PAIR

        self.labels = []
        self.references = []
        for i, pair in enumerate(self.data):
            votes = {}
            refs = []
            for j in range(i * per_pair, (i + 1) * per_pair):
                a, b, c, d = [x[0][0] for x in triplets[j]]
                key = b[0].strip() if d == 1 else c[0].strip()
                refs.append(a[0].strip())
                votes[key] = votes.get(key, 0) + 1

            # At least one and at most two distinct sentences may be voted for.
            assert 2 >= len(votes), votes
            assert len(votes) > 0

            # A candidate that never won a triplet is simply absent from
            # ``votes`` and therefore counts as zero votes.
            vote_a = votes.get(pair[1][0].strip(), 0)
            vote_b = votes.get(pair[2][0].strip(), 0)

            # Ties are broken randomly: when vote_a == vote_b the label is
            # 0 or 1 with probability 0.5 each.  Exactly one random draw is
            # made per pair so the seeded label sequence stays reproducible.
            label = 0 if vote_a > vote_b + random.random() - .5 else 1
            self.labels.append(label)
            self.references.append(refs)

    def __len__(self):
        """Return the number of candidate pairs."""
        return len(self.data)

    def get_image_path(self, filename: str):
        """Return ``<voc_path>/JPEGImages/<filename>``."""
        return os.path.join(self.voc_path, "JPEGImages", filename)

    def __getitem__(self, idx: int):
        """Return ``(img_path, a, b, references, category_str, label)`` for ``idx``."""
        vid, a, b = [x[0] for x in self.data[idx]]
        img_path = self.get_image_path(vid)
        category_str = self.idx2cat[self.categories[idx]]
        return (img_path, a.strip(), b.strip(), self.references[idx],
                category_str, self.labels[idx])
def sanity_check(detail=False):
    """Load the dataset from the ``pascal/`` layout and walk every sample.

    Asserts that the first sample can be fetched and that the image file of
    every sample exists on disk.  When ``detail`` is true, each sample's
    fields are printed as they are visited; otherwise only the progress bar
    is shown.
    """

    def dprint(*args, **kwargs):
        # Print only when detailed output was requested.
        if detail:
            print(*args, **kwargs)

    dataset = Pascal50sDataset(root="pascal/", voc_path="pascal/VOCdevkit/VOC2010")
    first = dataset[0]
    assert first is not None

    banner = "=" * 20
    for index, sample in enumerate(tqdm(dataset)):
        dprint(banner)
        dprint("sample:", index)
        dprint(banner)
        img_path, a, b, references, category, label = sample
        assert os.path.exists(img_path)
        fields = (
            ("img_path:", img_path),
            ("a:", a),
            ("b:", b),
            ("references:", references),
            ("category:", category),
            ("label:", label),
        )
        for tag, value in fields:
            dprint(tag, value)
if __name__ == "__main__":
    # Script entry point: run the dataset check without per-sample output.
    sanity_check(detail=False)