| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.autograd import Variable |
| from torchvision.models import vgg19 |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision.utils import save_image, make_grid |
| from torchvision.transforms import ToTensor |
|
|
| import numpy as np |
| import cv2 |
| import glob |
| import random |
| from PIL import Image |
| from tqdm import tqdm |
|
|
|
|
| |
| from opt import opt |
|
|
|
|
| class ImageDataset(Dataset): |
| @torch.no_grad() |
| def __init__(self, train_lr_paths, degrade_hr_paths, train_hr_paths): |
| |
| |
| |
| self.transform = transforms.Compose( |
| [ |
| transforms.ToTensor(), |
| ] |
| ) |
|
|
| self.files_lr = train_lr_paths |
| self.files_degrade_hr = degrade_hr_paths |
| self.files_hr = train_hr_paths |
|
|
| assert(len(self.files_lr) == len(self.files_hr)) |
| assert(len(self.files_lr) == len(self.files_degrade_hr)) |
|
|
|
|
| def augment(self, imgs, hflip=True, rotation=True): |
| """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). |
| |
| All the images in the list use the same augmentation. |
| |
| Args: |
| imgs (list[ndarray] | ndarray): Images to be augmented. If the input |
| is an ndarray, it will be transformed to a list. |
| hflip (bool): Horizontal flip. Default: True. |
| rotation (bool): Rotation. Default: True. |
| |
| Returns: |
| imgs (list[ndarray] | ndarray): Augmented images and flows. If returned |
| results only have one element, just return ndarray. |
| |
| """ |
| hflip = hflip and random.random() < 0.5 |
| vflip = rotation and random.random() < 0.5 |
| rot90 = rotation and random.random() < 0.5 |
|
|
| def _augment(img): |
| if hflip: |
| cv2.flip(img, 1, img) |
| if vflip: |
| cv2.flip(img, 0, img) |
| if rot90: |
| img = img.transpose(1, 0, 2) |
| return img |
|
|
|
|
| if not isinstance(imgs, list): |
| imgs = [imgs] |
| |
| imgs = [_augment(img) for img in imgs] |
| if len(imgs) == 1: |
| imgs = imgs[0] |
|
|
|
|
| return imgs |
| |
|
|
| def __getitem__(self, index): |
| |
| |
| img_lr = cv2.imread(self.files_lr[index % len(self.files_lr)]) |
| img_degrade_hr = cv2.imread(self.files_degrade_hr[index % len(self.files_degrade_hr)]) |
| img_hr = cv2.imread(self.files_hr[index % len(self.files_hr)]) |
|
|
| |
| if random.random() < opt["augment_prob"]: |
| img_lr, img_degrade_hr, img_hr = self.augment([img_lr, img_degrade_hr, img_hr]) |
| |
| |
| img_lr = self.transform(img_lr) |
| img_degrade_hr = self.transform(img_degrade_hr) |
| img_hr = self.transform(img_hr) |
|
|
|
|
| return {"lr": img_lr, "degrade_hr": img_degrade_hr, "hr": img_hr} |
| |
| def __len__(self): |
| assert(len(self.files_hr) == len(self.files_lr)) |
| return len(self.files_hr) |
|
|