from os.path import expanduser
import torch
import json
import torchvision
from general_utils import get_from_repository
from general_utils import log
from torchvision import transforms

# Five pairs of class identifiers written in WordNet-synset notation
# (e.g. 'cat.n.01').
# NOTE(review): not referenced in this file; presumably the unseen-class
# groups used by zero-shot splits elsewhere — confirm with callers.
PASCAL_VOC_CLASSES_ZS = [
    ['cattle.n.01', 'motorcycle.n.01'],
    ['aeroplane.n.01', 'sofa.n.01'],
    ['cat.n.01', 'television.n.03'],
    ['train.n.01', 'bottle.n.01'],
    ['chair.n.01', 'pot_plant.n.01'],
]

# Relative path (resolved against the current working directory) of the
# vendored JoEm code base that provides the VOC segmentation loader.
_JOEM_PATH = 'third_party/JoEm'

# Splits that __init__ knows how to configure.
_VALID_SPLITS = ('train', 'val')


class PascalZeroShot(object):
    """Pascal VOC segmentation dataset wrapper for zero-shot experiments.

    Wraps JoEm's ``VOCSegmentation`` loader, configured with the seen/unseen
    class indices that JoEm's ``get_seen_idx``/``get_unseen_idx`` return for
    ``n_unseen``, and resizes every sample to ``image_size`` x ``image_size``.

    Args:
        split: ``'train'`` or ``'val'``.
        n_unseen: forwarded to JoEm's ``get_unseen_idx`` / ``get_seen_idx``.
        image_size: side length of the square output image and label.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``.
    """

    def __init__(self, split, n_unseen, image_size=224) -> None:
        super().__init__()

        # Fail fast: an unknown split would otherwise leave ``self.voc`` unset
        # and only surface later as an AttributeError in __len__/__getitem__.
        if split not in _VALID_SPLITS:
            raise ValueError(
                f"split must be one of {_VALID_SPLITS}, got {split!r}")

        import sys
        # Presumably needed so JoEm's own internal imports resolve. Register
        # the path only once so repeated instantiation does not keep growing
        # sys.path with duplicate entries.
        if _JOEM_PATH not in sys.path:
            sys.path.append(_JOEM_PATH)
        from third_party.JoEm.data_loader.dataset import VOCSegmentation
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
        ])

        if split == 'train':
            # Training: JoEm-side transform enabled (base/crop size 312) and
            # remv_unseen_img=True.
            # NOTE(review): the exact semantics of these flags live in JoEm.
            self.voc = VOCSegmentation(
                get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                split=split, transform=True,
                transform_args=dict(base_size=312, crop_size=312),
                ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
        else:  # split == 'val' (validated above)
            self.voc = VOCSegmentation(
                get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                split=split, transform=False,
                ignore_bg=False, ignore_unseen=False)

        self.unseen_idx = get_unseen_idx(n_unseen)

    def __len__(self):
        """Return the number of samples in the wrapped VOC split."""
        return len(self.voc)

    def __getitem__(self, i):
        """Return the ``i``-th sample resized to ``image_size`` x ``image_size``.

        Args:
            i: index into the wrapped ``VOCSegmentation`` dataset.

        Returns:
            A pair of 1-tuples ``((image,), (label,))``. The image goes through
            ``self.transform`` (a ``Resize`` with its default interpolation).
            The label is cast with ``.long()`` and resized with
            nearest-neighbour interpolation, so no new, blended class values
            are introduced.
        """
        sample = self.voc[i]
        label = sample['label'].long()

        image = self.transform(sample['image'])

        # Resize expects a leading channel dimension: add it, resize, drop it.
        # NOTE(review): assumes the label is a 2-D (H, W) index map — confirm.
        label = transforms.Resize(
            (self.image_size, self.image_size),
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        )(label.unsqueeze(0))[0]

        return (image,), (label,)