# PyTorch Dataset for the Describable Textures Dataset (DTD).
import os
import copy
import numpy as np
from PIL import Image
from os.path import join
from itertools import chain
from collections import defaultdict
import torch
import torch.utils.data as data
from torchvision import transforms
DATA_ROOTS = 'data/DTD'
# wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz
# tar -xvzf dtd-r1.0.1.tar.gz


class DTD(data.Dataset):
    """Describable Textures Dataset (DTD).

    Expects the official archive layout:
        <root>/images/<category>/<image>.jpg
        <root>/labels/{train1,val1,test1}.txt

    In train mode the dataset is split 1's train + val lists combined;
    otherwise split 1's test list. Labels are integer indices into the
    sorted list of category names found in the chosen split.
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
        """
        Args:
            root: dataset root directory containing `images/` and `labels/`.
            train: if True, load train1 + val1 splits; otherwise test1.
            image_transforms: optional callable applied to each PIL image.
        """
        super().__init__()
        self.root = root
        self.train = train
        self.image_transforms = image_transforms
        self.paths, self.labels = self.load_images()

    def _read_split(self, filename):
        """Return the non-empty relative image paths listed in a split file."""
        split_path = os.path.join(self.root, 'labels', filename)
        with open(split_path, 'r') as f:
            # Lines look like '<category>/<image>.jpg'; skip blank lines
            # (e.g. a trailing newline) so they can't become bogus entries.
            return [line.rstrip('\n') for line in f if line.strip()]

    def load_images(self):
        """Return parallel lists of absolute image paths and integer labels."""
        if self.train:
            split_info = (self._read_split('train1.txt')
                          + self._read_split('val1.txt'))
        else:
            split_info = self._read_split('test1.txt')

        # Pull out categories from paths; the leading path component is the
        # category name, and sorted order fixes the label numbering.
        categories = sorted({p.split('/')[0] for p in split_info})
        # Dict lookup is O(1) per image, vs. the O(n) cost of list.index.
        category_to_label = {c: i for i, c in enumerate(categories)}

        all_paths = [join(self.root, 'images', p) for p in split_info]
        all_labels = [category_to_label[p.split('/')[0]] for p in split_info]
        return all_paths, all_labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is RGB, transformed if a
        transform was supplied, otherwise a PIL Image."""
        path = self.paths[index]
        label = self.labels[index]
        image = Image.open(path).convert(mode='RGB')
        if self.image_transforms:
            image = self.image_transforms(image)
        return image, label