import torch
import numpy as np


class Normalize(object):
    """Shift and scale a sample as ``(x - mean) / var``.

    The transform accepts either a bare value supporting ``-`` and ``/``
    (e.g. a numpy array) or a dict holding ``'img'`` and ``'gt'`` entries.
    For a dict only ``'img'`` is normalized; ``'gt'`` is passed through.

    NOTE(review): the divisor is used exactly as given (no square root is
    taken), so callers presumably pass a standard deviation despite the
    parameter being named ``var`` -- confirm against the call sites.

    Args:
        mean: Value subtracted from the image.
        var: Value the shifted image is divided by.
    """

    def __init__(self, mean, var):
        self.mean = mean
        self.var = var

    def __call__(self, sample):
        """Return the normalized sample.

        Args:
            sample: Either the image itself or a dict with ``'img'`` and
                ``'gt'`` keys.

        Returns:
            A new dict containing exactly ``'img'`` (normalized) and
            ``'gt'`` (unchanged) when a dict was given, otherwise the
            normalized value.

        Raises:
            KeyError: If a dict sample lacks ``'img'`` or ``'gt'``.
        """
        if not isinstance(sample, dict):
            return (sample - self.mean) / self.var

        img = sample['img']
        gt = sample['gt']
        return {'img': (img - self.mean) / self.var, 'gt': gt}


class RandHorizontalFlip(object):
    """Flip the image left/right with probability ``prob_aug``.

    The flip uses ``np.fliplr``, i.e. it reverses axis 1 of the array, and
    the result is copied so downstream code receives an array without
    negative strides.

    NOTE(review): for dict samples only ``'img'`` is flipped and ``'gt'``
    is returned untouched. That is correct for non-spatial targets; if
    ``gt`` is spatial (mask, boxes, keypoints) it would need the same flip
    -- confirm against the dataset.

    Args:
        prob_aug: Probability, in ``[0, 1]``, of applying the flip.
    """

    def __init__(self, prob_aug):
        self.prob_aug = prob_aug

    def __call__(self, sample):
        """Return the sample, flipped or not according to one random draw.

        Args:
            sample: Either the image array or a dict with ``'img'`` and
                ``'gt'`` keys.

        Returns:
            A new dict containing exactly ``'img'`` and ``'gt'`` when a
            dict was given, otherwise the (possibly flipped) array.

        Raises:
            KeyError: If a dict sample lacks ``'img'`` or ``'gt'``.
        """
        # One draw per call, made before the sample is inspected:
        # 1 (flip) with probability prob_aug, 0 otherwise.
        probabilities = np.array([self.prob_aug, 1 - self.prob_aug])
        do_flip = np.random.choice([1, 0], p=probabilities.ravel()) > 0.5

        if isinstance(sample, dict):
            img = sample['img']
            gt = sample['gt']
            if do_flip:
                img = np.fliplr(img).copy()
            return {'img': img, 'gt': gt}

        if do_flip:
            sample = np.fliplr(sample).copy()
        return sample


def _to_float_tensor(array):
    """Convert an array-like value to a ``torch.FloatTensor``."""
    # np.asarray returns an existing ndarray unchanged, and lets plain
    # Python scalars / lists (e.g. a class label) through as well.
    array = np.asarray(array)
    # torch.from_numpy rejects negative strides (e.g. an un-copied
    # np.fliplr view or x[::-1]); copy only in that case so ordinary
    # inputs keep following the original code path.
    if any(stride < 0 for stride in array.strides):
        array = array.copy()
    return torch.from_numpy(array).type(torch.FloatTensor)


class ToTensor(object):
    """Convert numpy data to ``torch.FloatTensor``.

    Accepts either a bare array or a dict with ``'img'`` and ``'gt'``
    entries; for a dict both entries are converted.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        """Return the sample converted to float tensors.

        Args:
            sample: Either an array-like value or a dict with ``'img'``
                and ``'gt'`` keys holding array-like values.

        Returns:
            A new dict containing exactly ``'img'`` and ``'gt'`` as
            ``torch.FloatTensor`` when a dict was given, otherwise a
            single ``torch.FloatTensor``.

        Raises:
            KeyError: If a dict sample lacks ``'img'`` or ``'gt'``.
        """
        if not isinstance(sample, dict):
            return _to_float_tensor(sample)

        img = sample['img']
        gt = sample['gt']
        return {'img': _to_float_tensor(img), 'gt': _to_float_tensor(gt)}