from abc import ABC, abstractmethod

import torchvision.transforms as transforms


class TransformsConfig(ABC):
    """Base class for a named set of image transform pipelines.

    Subclasses must implement :meth:`get_transforms`. The class derives from
    ``ABC`` so that ``@abstractmethod`` is enforced. Instantiating this class,
    or a subclass that does not override ``get_transforms``, raises
    ``TypeError``.
    """

    def __init__(self, opts):
        # Stored for subclasses; none of the subclasses in this module read it.
        self.opts = opts

    @abstractmethod
    def get_transforms(self):
        """Return a dict mapping transform names to transform pipelines (or None)."""


def _build_encode_transforms(size):
    """Build the encoder transform dict for images resized to ``size``.

    Args:
        size: Size passed to ``transforms.Resize``. A ``(h, w)`` pair is
            interpreted by torchvision as (height, width).

    Returns:
        A dict with these keys:
        - ``'transform_gt_train'``: resize, random horizontal flip (p=0.5),
          ``ToTensor``, normalize.
        - ``'transform_source'``: ``None``.
        - ``'transform_test'`` / ``'transform_inference'``: resize, ``ToTensor``,
          normalize. These have no augmentation and are separate ``Compose``
          instances.
    """

    def normalize():
        # Mean/std of 0.5 per channel maps ToTensor's [0, 1] range to [-1, 1].
        return transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

    def eval_pipeline():
        # Deterministic pipeline shared in shape by the test and inference keys.
        return transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            normalize()])

    return {
        'transform_gt_train': transforms.Compose([
            transforms.Resize(size),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
            normalize()]),
        'transform_source': None,
        'transform_test': eval_pipeline(),
        'transform_inference': eval_pipeline(),
    }


class EncodeTransforms(TransformsConfig):
    """Encoder transforms that resize images to 256x256."""

    def get_transforms(self):
        """Return the transform dict for 256x256 (height, width) inputs."""
        return _build_encode_transforms((256, 256))


class CarsEncodeTransforms(TransformsConfig):
    """Encoder transforms for the cars setting, resizing to 192x256 (h x w)."""

    def get_transforms(self):
        """Return the transform dict for 192x256 (height, width) inputs."""
        return _build_encode_transforms((192, 256))