Fabrice-TIERCELIN commited on
Commit
bc69feb
·
verified ·
1 Parent(s): a902bc0

Delete clipseg/datasets/pascal_zeroshot.py

Browse files
Files changed (1) hide show
  1. clipseg/datasets/pascal_zeroshot.py +0 -60
clipseg/datasets/pascal_zeroshot.py DELETED
@@ -1,60 +0,0 @@
1
- from os.path import expanduser
2
- import torch
3
- import json
4
- import torchvision
5
- from general_utils import get_from_repository
6
- from general_utils import log
7
- from torchvision import transforms
8
-
9
- PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],
10
- ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],
11
- ['chair.n.01', 'pot_plant.n.01']]
12
-
13
-
14
- class PascalZeroShot(object):
15
-
16
- def __init__(self, split, n_unseen, image_size=224) -> None:
17
- super().__init__()
18
-
19
- import sys
20
- sys.path.append('third_party/JoEm')
21
- from third_party.JoEm.data_loader.dataset import VOCSegmentation
22
- from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
23
-
24
- self.pascal_classes = VOC
25
- self.image_size = image_size
26
-
27
- self.transform = transforms.Compose([
28
- transforms.Resize((image_size, image_size)),
29
- ])
30
-
31
- if split == 'train':
32
- self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
33
- split=split, transform=True, transform_args=dict(base_size=312, crop_size=312),
34
- ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
35
- elif split == 'val':
36
- self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
37
- split=split, transform=False,
38
- ignore_bg=False, ignore_unseen=False)
39
-
40
- self.unseen_idx = get_unseen_idx(n_unseen)
41
-
42
- def __len__(self):
43
- return len(self.voc)
44
-
45
- def __getitem__(self, i):
46
-
47
- sample = self.voc[i]
48
- label = sample['label'].long()
49
- all_labels = [l for l in torch.where(torch.bincount(label.flatten())>0)[0].numpy().tolist() if l != 255]
50
- class_indices = [l for l in all_labels]
51
- class_names = [self.pascal_classes[l] for l in all_labels]
52
-
53
- image = self.transform(sample['image'])
54
-
55
- label = transforms.Resize((self.image_size, self.image_size),
56
- interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]
57
-
58
- return (image,), (label, )
59
-
60
-