| import random | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms.functional as F | |
| from torchvision.transforms import RandomCrop, InterpolationMode | |
| class CustomRandomResize(nn.Module): | |
| def __init__(self, scale=(0.5, 2.0), interpolation=InterpolationMode.BILINEAR): | |
| super().__init__() | |
| self.min_scale, self.max_scale = min(scale), max(scale) | |
| self.interpolation = interpolation | |
| def forward(self, img): | |
| if isinstance(img, torch.Tensor): | |
| height, width = img.shape[:2] | |
| else: | |
| width, height = img.size | |
| scale = random.uniform(self.min_scale, self.max_scale) | |
| new_size = [int(height * scale), int(width * scale)] | |
| img = F.resize(img, new_size, self.interpolation) | |
| return img | |
| class CustomRandomCrop(RandomCrop): | |
| def forward(self, img): | |
| """ | |
| Args: | |
| img (PIL Image or Tensor): Image to be cropped. | |
| Returns: | |
| PIL Image or Tensor: Cropped image. | |
| """ | |
| width, height = F.get_image_size(img) | |
| tar_h, tar_w = self.size | |
| tar_h = min(tar_h, height) | |
| tar_w = min(tar_w, width) | |
| i, j, h, w = self.get_params(img, (tar_h, tar_w)) | |
| return F.crop(img, i, j, h, w) | |