File size: 8,111 Bytes
8c6b5ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 |
import torch
import torchvision.transforms as T
from tabulate import tabulate
from torch.utils.data import Dataset as TorchDataset
from dassl.utils import read_image
from .datasets import build_dataset
from .samplers import build_sampler
from .transforms import INTERPOLATION_MODES, build_transform
def build_data_loader(
    cfg,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None
):
    """Build a ``torch.utils.data.DataLoader`` over ``data_source``.

    Args:
        cfg: Config node. ``DATALOADER.NUM_WORKERS`` and ``USE_CUDA`` are read
            here; the node is also forwarded to the sampler and the wrapper.
        sampler_type: Sampler name passed to ``build_sampler``.
        data_source: Sequence of data items to load (must support ``len()``).
        batch_size: Number of items per batch.
        n_domain: Forwarded to ``build_sampler``.
        n_ins: Forwarded to ``build_sampler``.
        tfm: Transform (or list/tuple of transforms) given to the wrapper.
        is_train: Whether this loader is for training; passed to the wrapper
            and used to decide ``drop_last``.
        dataset_wrapper: Dataset class used to wrap ``data_source``;
            defaults to ``DatasetWrapper``.

    Returns:
        The constructed data loader.

    Raises:
        ValueError: If the resulting data loader yields no batches.
    """
    # Build sampler
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins
    )

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        # Drop the trailing partial batch only during training, and only when
        # at least one full batch exists; with fewer items than batch_size,
        # drop_last could otherwise leave the loader with no batches.
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA)
    )

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an empty loader through silently.
    if len(data_loader) == 0:
        raise ValueError(
            f"Data loader is empty (sampler_type={sampler_type!r}, "
            f"batch_size={batch_size}, num_items={len(data_source)})"
        )
    return data_loader
class DataManager:
    """Build the dataset, its transforms and all data loaders from ``cfg``.

    After construction the instance exposes ``dataset``, ``train_loader_x``,
    ``train_loader_u`` (``None`` when the dataset has no ``train_u`` split),
    ``val_loader`` (``None`` when there is no ``val`` split) and
    ``test_loader``.
    """

    def __init__(
        self,
        cfg,
        custom_tfm_train=None,
        custom_tfm_test=None,
        dataset_wrapper=None
    ):
        dataset = build_dataset(cfg)

        # A caller-supplied transform takes precedence over the cfg default.
        if custom_tfm_train is not None:
            print("* Using custom transform for training")
            tfm_train = custom_tfm_train
        else:
            tfm_train = build_transform(cfg, is_train=True)

        if custom_tfm_test is not None:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test
        else:
            tfm_test = build_transform(cfg, is_train=False)

        x_opts = cfg.DATALOADER.TRAIN_X

        def make_train_loader(source, opts):
            # Training loaders take sampler/batch settings from ``opts``.
            return build_data_loader(
                cfg,
                sampler_type=opts.SAMPLER,
                data_source=source,
                batch_size=opts.BATCH_SIZE,
                n_domain=opts.N_DOMAIN,
                n_ins=opts.N_INS,
                tfm=tfm_train,
                is_train=True,
                dataset_wrapper=dataset_wrapper
            )

        def make_eval_loader(source):
            # Evaluation loaders (val/test) share the TEST settings.
            test_opts = cfg.DATALOADER.TEST
            return build_data_loader(
                cfg,
                sampler_type=test_opts.SAMPLER,
                data_source=source,
                batch_size=test_opts.BATCH_SIZE,
                tfm=tfm_test,
                is_train=False,
                dataset_wrapper=dataset_wrapper
            )

        loader_x = make_train_loader(dataset.train_x, x_opts)

        loader_u = None
        if dataset.train_u:
            u_opts = cfg.DATALOADER.TRAIN_U
            if u_opts.SAME_AS_X:
                # Reuse the labeled loader's sampler/batch configuration.
                u_opts = x_opts
            loader_u = make_train_loader(dataset.train_u, u_opts)

        loader_val = make_eval_loader(dataset.val) if dataset.val else None
        loader_test = make_eval_loader(dataset.test)

        # Attributes
        self._num_classes = dataset.num_classes
        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
        self._lab2cname = dataset.lab2cname

        # Dataset and data-loaders
        self.dataset = dataset
        self.train_loader_x = loader_x
        self.train_loader_u = loader_u
        self.val_loader = loader_val
        self.test_loader = loader_test

        if cfg.VERBOSE:
            self.show_dataset_summary(cfg)

    @property
    def num_classes(self):
        """Number of classes, as reported by the dataset."""
        return self._num_classes

    @property
    def num_source_domains(self):
        """Length of ``cfg.DATASET.SOURCE_DOMAINS``."""
        return self._num_source_domains

    @property
    def lab2cname(self):
        """Label-to-class-name mapping, as reported by the dataset."""
        return self._lab2cname

    def show_dataset_summary(self, cfg):
        """Print a table with the dataset name, domains and split sizes."""
        name = cfg.DATASET.NAME
        sources = cfg.DATASET.SOURCE_DOMAINS
        targets = cfg.DATASET.TARGET_DOMAINS

        rows = [["Dataset", name]]
        for label, domains in (("Source", sources), ("Target", targets)):
            if domains:
                rows.append([label, domains])
        rows.append(["# classes", f"{self.num_classes:,}"])

        data = self.dataset
        # (label, split, always shown?) - optional splits appear only if non-empty.
        split_rows = (
            ("# train_x", data.train_x, True),
            ("# train_u", data.train_u, False),
            ("# val", data.val, False),
            ("# test", data.test, True),
        )
        for label, split, always in split_rows:
            if always or split:
                rows.append([label, f"{len(split):,}"])

        print(tabulate(rows))
class DatasetWrapper(TorchDataset):
    """Map-style dataset that reads each item's image and applies transforms.

    Each element of ``data_source`` must expose ``label``, ``domain`` and
    ``impath``. ``__getitem__`` returns a dict with those fields, the sample
    ``index``, the transformed image(s) under ``img``/``img2``/..., and
    optionally an un-augmented tensor under ``img0``.
    """

    def __init__(self, cfg, data_source, transform=None, is_train=False):
        self.cfg = cfg
        self.data_source = data_source
        # Either a single callable or a list/tuple of callables.
        self.transform = transform
        self.is_train = is_train
        # Multiple augmented views per image are only produced for training.
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                f"Cannot augment the image {self.k_tfm} times "
                "because transform is None"
            )

        # Deterministic pipeline (resize -> tensor [-> normalize]) used for
        # the ``img0`` output; it applies no data augmentation.
        interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
        pipeline = [
            T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode),
            T.ToTensor(),
        ]
        if "normalize" in cfg.INPUT.TRANSFORMS:
            pipeline.append(
                T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
            )
        self.to_tensor = T.Compose(pipeline)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]
        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath,
            "index": idx
        }
        img0 = read_image(item.impath)

        tfms = self.transform
        if tfms is None:
            output["img"] = img0
        elif isinstance(tfms, (list, tuple)):
            # First transform fills "img", later ones "img2", "img3", ...
            for n, tfm in enumerate(tfms, start=1):
                key = "img" if n == 1 else f"img{n}"
                output[key] = self._transform_image(tfm, img0)
        else:
            output["img"] = self._transform_image(tfms, img0)

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)  # without any augmentation
        return output

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` to ``img0`` ``k_tfm`` times.

        Returns the single result when exactly one view is produced,
        otherwise the list of views.
        """
        views = [tfm(img0) for _ in range(self.k_tfm)]
        return views[0] if len(views) == 1 else views
|