|
import torch |
|
import torchvision |
|
|
|
import math |
|
import cv2 |
|
import numpy as np |
|
from scipy.ndimage import rotate |
|
|
|
|
|
class RandCrop(object):
    """Randomly crop an LQ image and a GT image at independent positions.

    The LQ crop has size ``crop_size`` and the GT crop has size
    ``scale * crop_size``. The two crop offsets are drawn independently, so
    the crops are NOT spatially aligned (``RandCrop_pair`` produces aligned
    crops).

    Args:
        crop_size: int for a square crop, or a (height, width) pair.
        scale: factor between GT and LQ crop sizes. It is used in slice
            bounds, so it must be an int.
    """

    def __init__(self, crop_size, scale):
        assert isinstance(crop_size, (int, tuple, list))
        if isinstance(crop_size, int):
            self.crop_size = (crop_size, crop_size)
        else:
            assert len(crop_size) == 2
            self.crop_size = tuple(crop_size)

        self.scale = scale

    @staticmethod
    def _random_origin(h, w, crop_h, crop_w):
        """Draw a uniform (top, left) so a crop_h x crop_w window fits in h x w.

        Raises:
            ValueError: if the image is smaller than the crop.
        """
        if h < crop_h or w < crop_w:
            raise ValueError(
                f'image of size {h}x{w} is smaller than crop size {crop_h}x{crop_w}')
        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset (h - crop_h) reachable and allows a crop as large as the image.
        top = np.random.randint(0, h - crop_h + 1)
        left = np.random.randint(0, w - crop_w + 1)
        return top, left

    def __call__(self, sample):
        """Return a new sample dict holding the cropped 'img_LQ' and 'img_GT'.

        Images are indexed as (height, width, channels).
        """
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']

        new_h, new_w = self.crop_size
        gt_h, gt_w = self.scale * new_h, self.scale * new_w

        top, left = self._random_origin(img_LQ.shape[0], img_LQ.shape[1], new_h, new_w)
        img_LQ_crop = img_LQ[top:top + new_h, left:left + new_w, :]

        # Independent draw: the GT crop position is unrelated to the LQ one.
        top, left = self._random_origin(img_GT.shape[0], img_GT.shape[1], gt_h, gt_w)
        img_GT_crop = img_GT[top:top + gt_h, left:left + gt_w, :]

        return {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}
|
|
|
|
|
class RandRotate(object):
    """Rotate an LQ/GT image pair by the same random multiple of 90 degrees.

    Each outcome has probability 0.25: rotation by 90, 180 or 270 degrees,
    or no rotation. Rotation is applied in the plane of the first two axes,
    i.e. (height, width) for HWC images.
    """

    def __call__(self, sample):
        """Return a new sample dict with 'img_LQ' and 'img_GT' rotated together."""
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']

        prob_rotate = np.random.random()
        if prob_rotate < 0.25:
            k = 1  # 90 degrees
        elif prob_rotate < 0.5:
            k = 2  # 180 degrees
        elif prob_rotate < 0.75:
            k = 3  # 270 degrees
        else:
            k = 0  # unchanged

        if k:
            # np.rot90 is an exact index permutation: no interpolation and
            # the dtype is preserved. It returns a negative-stride view, and
            # .copy() makes it contiguous so torch.from_numpy accepts it.
            img_LQ = np.rot90(img_LQ, k).copy()
            img_GT = np.rot90(img_GT, k).copy()

        return {'img_LQ': img_LQ, 'img_GT': img_GT}
|
|
|
|
|
class RandHorizontalFlip(object):
    """Mirror an LQ/GT image pair left-to-right together, with probability 0.5."""

    def __call__(self, sample):
        """Return a new sample dict, flipped along axis 1 half of the time."""
        lq = sample['img_LQ']
        gt = sample['img_GT']

        if np.random.random() >= 0.5:
            # No flip: hand back the original arrays unchanged.
            return {'img_LQ': lq, 'img_GT': gt}

        # fliplr yields a negative-stride view; copy to get a contiguous array.
        return {
            'img_LQ': np.fliplr(lq).copy(),
            'img_GT': np.fliplr(gt).copy(),
        }
|
|
|
|
|
class ToTensor(object):
    """Convert the HWC numpy images of a sample into CHW torch tensors."""

    def __call__(self, sample):
        """Return a new dict with 'img_LQ' and 'img_GT' as CHW tensors.

        torch.from_numpy shares memory with the (transposed) source arrays;
        dtype is left unchanged.
        """
        def _hwc_to_chw_tensor(array):
            # Move the channel axis (2) to the front.
            return torch.from_numpy(array.transpose((2, 0, 1)))

        return {key: _hwc_to_chw_tensor(sample[key]) for key in ('img_LQ', 'img_GT')}
|
|
|
|
|
class VGG19PerceptualLoss(torch.nn.Module):
    """L1 distance between frozen VGG19 feature maps of two image batches.

    Args:
        feature_layer: number of leading modules of ``vgg19().features`` to
            keep. NOTE(review): 35 presumably ends at conv5_4 before its
            ReLU, as in ESRGAN-style losses; confirm.

    NOTE(review): inputs are fed to VGG as-is; no ImageNet mean/std
    normalization happens here. Confirm that callers normalize if needed.
    """

    def __init__(self, feature_layer=35):
        super().__init__()
        # Downloads/loads the default pretrained torchvision weights.
        backbone = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
        leading_layers = list(backbone.features.children())[:feature_layer]

        extractor = torch.nn.Sequential(*leading_layers)
        extractor.eval()
        # Freeze every parameter: the network is only a fixed feature extractor.
        extractor.requires_grad_(False)
        self.features = extractor

    def forward(self, source, target):
        """Return the mean absolute difference between the two feature maps."""
        source_features = self.features(source)
        target_features = self.features(target)
        return torch.nn.functional.l1_loss(source_features, target_features)
|
|
|
|
|
class RandCrop_pair(object):
    """Randomly crop spatially aligned patches from an LQ/GT image pair.

    A ``crop_size`` window is drawn uniformly inside the LQ image. The GT
    crop is the same window mapped to GT coordinates (offsets and size
    multiplied by ``scale``), so both crops cover the same region.

    Args:
        crop_size: int for a square crop, or a (height, width) pair.
        scale: factor between GT and LQ resolution. It is used in slice
            bounds, so it must be an int.
    """

    def __init__(self, crop_size, scale):
        assert isinstance(crop_size, (int, tuple, list))
        if isinstance(crop_size, int):
            self.crop_size = (crop_size, crop_size)
        else:
            assert len(crop_size) == 2
            self.crop_size = tuple(crop_size)

        self.scale = scale

    def __call__(self, sample):
        """Return a new sample dict holding the aligned 'img_LQ'/'img_GT' crops.

        Images are indexed as (height, width, channels).

        Raises:
            ValueError: if the LQ image is smaller than ``crop_size``.
        """
        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']

        h, w = img_LQ.shape[:2]
        new_h, new_w = self.crop_size
        if h < new_h or w < new_w:
            raise ValueError(
                f'LQ image of size {h}x{w} is smaller than crop size {new_h}x{new_w}')

        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable and allows a crop as large as the image.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        img_LQ_crop = img_LQ[top:top + new_h, left:left + new_w, :]

        # Map the LQ window into GT coordinates so the crops stay aligned.
        # NOTE(review): GT size is not checked. If GT is smaller than
        # scale * LQ, numpy slicing silently returns a truncated crop.
        gt_top, gt_left = self.scale * top, self.scale * left
        img_GT_crop = img_GT[gt_top:gt_top + self.scale * new_h,
                             gt_left:gt_left + self.scale * new_w, :]

        return {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}