|
import torch |
|
import torchvision.transforms as T |
|
from tabulate import tabulate |
|
from torch.utils.data import Dataset as TorchDataset |
|
|
|
from dassl.utils import read_image |
|
|
|
from .datasets import build_dataset |
|
from .samplers import build_sampler |
|
from .transforms import INTERPOLATION_MODES, build_transform |
|
|
|
|
|
def build_data_loader(
    cfg,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None
):
    """Build a ``torch.utils.data.DataLoader`` over ``data_source``.

    Args:
        cfg: experiment config. This function reads
            ``cfg.DATALOADER.NUM_WORKERS`` and ``cfg.USE_CUDA``; the config
            is also forwarded to the sampler builder and the dataset wrapper.
        sampler_type: sampler name forwarded to ``build_sampler``.
        data_source: sized collection of data items to load from.
        batch_size: number of items per batch.
        n_domain: forwarded to ``build_sampler``.
        n_ins: forwarded to ``build_sampler``.
        tfm: transform (or list/tuple of transforms) handed to the dataset
            wrapper as ``transform``.
        is_train: whether the loader is used for training. It is handed to
            the dataset wrapper and also controls ``drop_last``.
        dataset_wrapper: callable used to wrap ``data_source`` into a torch
            dataset, called as
            ``dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train)``.
            Defaults to ``DatasetWrapper``.

    Returns:
        The constructed, non-empty data loader.

    Raises:
        AssertionError: if the resulting loader yields no batches.
    """
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins
    )

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        # Only drop the incomplete last batch when training AND there is at
        # least one full batch; otherwise a small data source would produce
        # no batches at all.
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA)
    )

    # Explicit raise instead of a bare ``assert`` so the check is still
    # enforced when Python runs with -O (which strips assert statements).
    # AssertionError is kept so existing callers see the same exception type.
    if len(data_loader) == 0:
        raise AssertionError(
            "build_data_loader produced an empty data loader "
            f"(sampler_type={sampler_type!r}, batch_size={batch_size})"
        )

    return data_loader
|
|
|
|
|
class DataManager:
    """Create the dataset, transforms and data loaders for an experiment.

    After construction the instance exposes ``dataset``, ``train_loader_x``,
    ``train_loader_u`` (``None`` when the dataset has no ``train_u`` split),
    ``val_loader`` (``None`` when the dataset has no ``val`` split) and
    ``test_loader``, plus the read-only properties ``num_classes``,
    ``num_source_domains`` and ``lab2cname``.
    """

    def __init__(
        self,
        cfg,
        custom_tfm_train=None,
        custom_tfm_test=None,
        dataset_wrapper=None
    ):
        """Build everything from ``cfg``.

        Args:
            cfg: experiment config; read for the ``DATALOADER`` and
                ``DATASET`` sections and the ``VERBOSE`` flag.
            custom_tfm_train: optional transform used for the training
                loaders instead of ``build_transform(cfg, is_train=True)``.
            custom_tfm_test: optional transform used for the val/test
                loaders instead of ``build_transform(cfg, is_train=False)``.
            dataset_wrapper: optional dataset wrapper forwarded unchanged to
                every ``build_data_loader`` call.
        """
        dataset = build_dataset(cfg)

        # Training-time transform: a caller-supplied one takes precedence.
        if custom_tfm_train is not None:
            print("* Using custom transform for training")
            tfm_train = custom_tfm_train
        else:
            tfm_train = build_transform(cfg, is_train=True)

        # Test-time transform: same rule.
        if custom_tfm_test is not None:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test
        else:
            tfm_test = build_transform(cfg, is_train=False)

        def make_loader(**options):
            # Every loader shares the config and the dataset wrapper.
            return build_data_loader(
                cfg, dataset_wrapper=dataset_wrapper, **options
            )

        # Labeled training loader.
        loader_x = make_loader(
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=dataset.train_x,
            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train,
            is_train=True
        )

        # Unlabeled training loader, only when the dataset has such a split.
        loader_u = None
        if dataset.train_u:
            u_sampler = cfg.DATALOADER.TRAIN_U.SAMPLER
            u_batch_size = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
            u_n_domain = cfg.DATALOADER.TRAIN_U.N_DOMAIN
            u_n_ins = cfg.DATALOADER.TRAIN_U.N_INS

            # SAME_AS_X replaces the TRAIN_U settings with the TRAIN_X ones.
            if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
                u_sampler = cfg.DATALOADER.TRAIN_X.SAMPLER
                u_batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
                u_n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
                u_n_ins = cfg.DATALOADER.TRAIN_X.N_INS

            loader_u = make_loader(
                sampler_type=u_sampler,
                data_source=dataset.train_u,
                batch_size=u_batch_size,
                n_domain=u_n_domain,
                n_ins=u_n_ins,
                tfm=tfm_train,
                is_train=True
            )

        # Validation loader, only when the dataset has a val split.
        loader_val = None
        if dataset.val:
            loader_val = make_loader(
                sampler_type=cfg.DATALOADER.TEST.SAMPLER,
                data_source=dataset.val,
                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
                tfm=tfm_test,
                is_train=False
            )

        # Test loader is always built.
        loader_test = make_loader(
            sampler_type=cfg.DATALOADER.TEST.SAMPLER,
            data_source=dataset.test,
            batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
            tfm=tfm_test,
            is_train=False
        )

        # Values backing the read-only properties.
        self._num_classes = dataset.num_classes
        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
        self._lab2cname = dataset.lab2cname

        # Dataset and loaders.
        self.dataset = dataset
        self.train_loader_x = loader_x
        self.train_loader_u = loader_u
        self.val_loader = loader_val
        self.test_loader = loader_test

        if cfg.VERBOSE:
            self.show_dataset_summary(cfg)

    @property
    def num_classes(self):
        """Value of ``dataset.num_classes`` captured at construction."""
        return self._num_classes

    @property
    def num_source_domains(self):
        """Length of ``cfg.DATASET.SOURCE_DOMAINS`` at construction."""
        return self._num_source_domains

    @property
    def lab2cname(self):
        """Value of ``dataset.lab2cname`` captured at construction."""
        return self._lab2cname

    def show_dataset_summary(self, cfg):
        """Print a table with the dataset name, domains and split sizes.

        The source/target rows and the train_u/val rows are only included
        when the corresponding value is truthy.
        """
        name = cfg.DATASET.NAME
        sources = cfg.DATASET.SOURCE_DOMAINS
        targets = cfg.DATASET.TARGET_DOMAINS

        def thousands(number):
            # Render an integer with thousands separators, e.g. 12,345.
            return f"{number:,}"

        rows = [["Dataset", name]]
        if sources:
            rows.append(["Source", sources])
        if targets:
            rows.append(["Target", targets])
        rows.append(["# classes", thousands(self.num_classes)])
        rows.append(["# train_x", thousands(len(self.dataset.train_x))])
        if self.dataset.train_u:
            rows.append(["# train_u", thousands(len(self.dataset.train_u))])
        if self.dataset.val:
            rows.append(["# val", thousands(len(self.dataset.val))])
        rows.append(["# test", thousands(len(self.dataset.test))])

        print(tabulate(rows))
|
|
|
|
|
class DatasetWrapper(TorchDataset):
    """Torch dataset that turns data-source items into sample dicts.

    Each sample is a dict with the keys ``label``, ``domain``, ``impath``
    and ``index``, plus the image entries described in ``__getitem__``.
    """

    def __init__(self, cfg, data_source, transform=None, is_train=False):
        """Store the inputs and prepare the plain to-tensor pipeline.

        Args:
            cfg: config; read for ``DATALOADER.K_TRANSFORMS``,
                ``DATALOADER.RETURN_IMG0`` and the ``INPUT`` section.
            data_source: indexable, sized collection of items exposing
                ``label``, ``domain`` and ``impath``.
            transform: a single transform, a list/tuple of transforms, or
                ``None``.
            is_train: when true, each transform is applied
                ``cfg.DATALOADER.K_TRANSFORMS`` times per image; otherwise
                once.

        Raises:
            ValueError: if more than one augmentation is requested but no
                transform was given.
        """
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform
        self.is_train = is_train

        # Repeated augmentation is only used in training mode.
        if is_train:
            self.k_tfm = cfg.DATALOADER.K_TRANSFORMS
        else:
            self.k_tfm = 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0

        if transform is None and self.k_tfm > 1:
            raise ValueError(
                "Cannot augment the image {} times "
                "because transform is None".format(self.k_tfm)
            )

        # Pipeline used for "img0": resize, convert to tensor, and
        # normalize only if "normalize" is listed in cfg.INPUT.TRANSFORMS.
        interpolation = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
        pipeline = [
            T.Resize(cfg.INPUT.SIZE, interpolation=interpolation),
            T.ToTensor(),
        ]
        if "normalize" in cfg.INPUT.TRANSFORMS:
            pipeline.append(
                T.Normalize(
                    mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
                )
            )
        self.to_tensor = T.Compose(pipeline)

    def __len__(self):
        """Return the number of items in the data source."""
        return len(self.data_source)

    def __getitem__(self, idx):
        """Load item ``idx`` and return its sample dict.

        Image entries:
            * ``img`` - result of the (first) transform, or the image as
              returned by ``read_image`` when no transform is set.
            * ``img2``, ``img3``, ... - results of the 2nd, 3rd, ...
              transform when ``transform`` is a list/tuple.
            * ``img0`` - the image passed through the to-tensor pipeline,
              present only when ``return_img0`` is set.
        """
        record = self.data_source[idx]

        sample = {
            "label": record.label,
            "domain": record.domain,
            "impath": record.impath,
            "index": idx
        }

        raw = read_image(record.impath)
        tfm = self.transform

        if tfm is None:
            sample["img"] = raw
        elif isinstance(tfm, (list, tuple)):
            # The first transform writes "img"; later ones get a numeric
            # suffix starting at 2.
            for position, single_tfm in enumerate(tfm, start=1):
                key = "img" if position == 1 else "img" + str(position)
                sample[key] = self._transform_image(single_tfm, raw)
        else:
            sample["img"] = self._transform_image(tfm, raw)

        if self.return_img0:
            sample["img0"] = self.to_tensor(raw)

        return sample

    def _transform_image(self, tfm, img0):
        """Apply ``tfm`` to ``img0`` ``self.k_tfm`` times.

        Returns the single result when exactly one was produced, otherwise
        the list of results.
        """
        views = [tfm(img0) for _ in range(self.k_tfm)]
        return views[0] if len(views) == 1 else views
|
|