from os.path import expanduser
import torch
import json
import torchvision
from general_utils import get_from_repository
from general_utils import log
from torchvision import transforms

# WordNet synset names of the Pascal VOC classes used as unseen classes in the zero-shot setting.
PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
                         ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
                         ['chair.n.01', 'pot_plant.n.01']]


class PascalZeroShot(object):

    def __init__(self, split, n_unseen, image_size=224) -> None:
        super().__init__()

        import sys
        sys.path.append('third_party/JoEm')
        # JoEm provides the zero-shot VOC dataset and the seen/unseen class index helpers.
        from third_party.JoEm.data_loader.dataset import VOCSegmentation
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
        ])

        if split == 'train':
            # Training split: JoEm augmentations enabled, images containing unseen classes removed.
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
                                       ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
        elif split == 'val':
            # Validation split: no augmentation, unseen classes kept for evaluation.
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=False,
                                       ignore_bg=False, ignore_unseen=False)

        self.unseen_idx = get_unseen_idx(n_unseen)

    def __len__(self):
        return len(self.voc)

    def __getitem__(self, i):
        sample = self.voc[i]
        label = sample['label'].long()

        # Classes actually present in this label map (255 is the VOC ignore index).
        all_labels = [l for l in torch.where(torch.bincount(label.flatten()) > 0)[0].numpy().tolist() if l != 255]
        class_indices = [l for l in all_labels]
        class_names = [self.pascal_classes[l] for l in all_labels]

        image = self.transform(sample['image'])

        # Resize the label map with nearest-neighbour interpolation to keep class ids intact.
        label = transforms.Resize((self.image_size, self.image_size),
                                  interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]

        return (image,), (label,)
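

# Usage sketch (illustrative only): it assumes the PASCAL VOC data and the
# third_party/JoEm loader are set up as the imports in __init__ require, and
# the argument values below are just examples.
#
#     dataset = PascalZeroShot(split='val', n_unseen=2)
#     (image,), (label,) = dataset[0]
#     # image: resized image, label: (image_size, image_size) map of class indices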