|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
from torchvision.models import vgg19 |
|
import torchvision.transforms as transforms |
|
from torch.utils.data import DataLoader, Dataset |
|
from torchvision.utils import save_image, make_grid |
|
from torchvision.transforms import ToTensor |
|
|
|
import numpy as np |
|
import cv2 |
|
import glob |
|
import random |
|
from PIL import Image |
|
from tqdm import tqdm |
|
|
|
|
|
|
|
from opt import opt |
|
|
|
|
|
class ImageDataset(Dataset):
    """Paired super-resolution dataset of LR, degraded-HR and HR images.

    Each item is loaded from three parallel lists of file paths with
    ``cv2.imread``. With probability ``opt["augment_prob"]`` the three
    images receive the *same* random flip/transpose augmentation. Each is
    then converted by ``self.transform`` (``transforms.ToTensor``).

    Args:
        train_lr_paths: Paths of the low-resolution images.
        degrade_hr_paths: Paths of the degraded high-resolution images.
        train_hr_paths: Paths of the target high-resolution images.

    Raises:
        ValueError: If the three path lists do not have the same length.
    """

    def __init__(self, train_lr_paths, degrade_hr_paths, train_hr_paths):
        # Validate first so a mismatched split fails fast. Explicit raises
        # are used instead of ``assert``, which is stripped under ``python -O``.
        if len(train_lr_paths) != len(train_hr_paths):
            raise ValueError(
                f"LR/HR path count mismatch: {len(train_lr_paths)} LR vs "
                f"{len(train_hr_paths)} HR"
            )
        if len(train_lr_paths) != len(degrade_hr_paths):
            raise ValueError(
                f"LR/degraded-HR path count mismatch: {len(train_lr_paths)} LR vs "
                f"{len(degrade_hr_paths)} degraded HR"
            )

        self.transform = transforms.Compose([transforms.ToTensor()])

        self.files_lr = train_lr_paths
        self.files_degrade_hr = degrade_hr_paths
        self.files_hr = train_hr_paths

    def augment(self, imgs, hflip=True, rotation=True):
        """Apply one shared random flip/transpose augmentation to ``imgs``.

        Up to three independent coin flips (p = 0.5 each) are drawn once and
        applied identically to every image, so paired images stay aligned:

        * horizontal flip (only drawn when ``hflip`` is True),
        * vertical flip (only drawn when ``rotation`` is True),
        * swap of the first two axes, i.e. a transpose (only drawn when
          ``rotation`` is True).

        With all three enabled, the eight combinations cover every
        90-degree rotation and mirror image.

        Args:
            imgs (list[ndarray] | ndarray): Images laid out with height and
                width as the first two axes, e.g. (H, W) or (H, W, C). A
                bare ndarray is treated as a one-element list.
            hflip (bool): Allow horizontal flips. Default: True.
            rotation (bool): Allow vertical flips and transposes.
                Default: True.

        Returns:
            list[ndarray] | ndarray: C-contiguous augmented images in input
            order; a single ndarray when only one image was given. The input
            arrays are never modified.
        """
        # Draw order is significant for reproducibility under a fixed seed:
        # the hflip draw is skipped entirely when ``hflip`` is False.
        do_hflip = hflip and random.random() < 0.5
        do_vflip = rotation and random.random() < 0.5
        do_transpose = rotation and random.random() < 0.5

        def _augment(img):
            if do_hflip:
                img = np.flip(img, axis=1)
            if do_vflip:
                img = np.flip(img, axis=0)
            if do_transpose:
                # swapaxes handles (H, W) and (H, W, C) alike.
                img = np.swapaxes(img, 0, 1)
            # Flipped views have negative strides, which torch.from_numpy
            # (used by ToTensor) rejects, so materialise a contiguous array.
            return np.ascontiguousarray(img)

        if not isinstance(imgs, list):
            imgs = [imgs]

        imgs = [_augment(img) for img in imgs]
        if len(imgs) == 1:
            imgs = imgs[0]
        return imgs

    @staticmethod
    def _read_image(path):
        """Load ``path`` with ``cv2.imread``, raising if OpenCV cannot read it."""
        img = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    def __getitem__(self, index):
        """Return the (optionally augmented) image triple at ``index``.

        Returns:
            dict: ``{"lr": ..., "degrade_hr": ..., "hr": ...}``, where each
            value is the output of ``self.transform``.

        Raises:
            FileNotFoundError: If any of the three images cannot be read.
        """
        # NOTE(review): cv2.imread returns BGR channel order and nothing here
        # converts to RGB; confirm downstream consumers expect BGR.
        img_lr = self._read_image(self.files_lr[index % len(self.files_lr)])
        img_degrade_hr = self._read_image(
            self.files_degrade_hr[index % len(self.files_degrade_hr)]
        )
        img_hr = self._read_image(self.files_hr[index % len(self.files_hr)])

        if random.random() < opt["augment_prob"]:
            img_lr, img_degrade_hr, img_hr = self.augment(
                [img_lr, img_degrade_hr, img_hr]
            )

        return {
            "lr": self.transform(img_lr),
            "degrade_hr": self.transform(img_degrade_hr),
            "hr": self.transform(img_hr),
        }

    def __len__(self):
        """Return the number of samples (the path lists are equal-length)."""
        return len(self.files_hr)
|
|