# APISR / architecture / dataset.py
# Uploaded by: HikariDawn
# Commit: 561c629 ("feat: initial push")
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid
from torchvision.transforms import ToTensor
import numpy as np
import cv2
import glob
import random
from PIL import Image
from tqdm import tqdm
# from degradation.degradation_main import degredate_process, preparation
from opt import opt
class ImageDataset(Dataset):
    """Paired dataset yielding LR, degraded-HR and HR image tensors.

    The three path lists are paired by index: item ``i`` is built from
    ``train_lr_paths[i]``, ``degrade_hr_paths[i]`` and ``train_hr_paths[i]``.
    Images are read with OpenCV (which returns BGR channel order; the order is
    kept as-is). With probability ``opt["augment_prob"]`` the same random
    flip/transpose is applied to all three. They are then converted with
    ``transforms.ToTensor()``.
    """

    def __init__(self, train_lr_paths, degrade_hr_paths, train_hr_paths):
        """Store the three index-aligned path lists.

        Args:
            train_lr_paths (list[str]): Paths of the low-resolution images.
            degrade_hr_paths (list[str]): Paths of the degraded HR images.
            train_hr_paths (list[str]): Paths of the ground-truth HR images.

        Raises:
            ValueError: If the three lists do not have the same length.
        """
        # ToTensor turns an HxWxC uint8 ndarray into a CxHxW float tensor in [0, 1].
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )
        self.files_lr = train_lr_paths
        self.files_degrade_hr = degrade_hr_paths
        self.files_hr = train_hr_paths
        # Samples are paired by index, so the lists must line up one-to-one.
        # Raise explicitly: ``assert`` statements are removed under ``python -O``.
        if not (len(self.files_lr) == len(self.files_degrade_hr) == len(self.files_hr)):
            raise ValueError(
                "Dataset path lists differ in length: "
                f"lr={len(self.files_lr)}, "
                f"degrade_hr={len(self.files_degrade_hr)}, "
                f"hr={len(self.files_hr)}"
            )

    def augment(self, imgs, hflip=True, rotation=True):
        """Apply one random flip/transpose combination to every image.

        Three independent coin flips (p=0.5 each) decide whether to:
          * flip horizontally (only considered when ``hflip`` is True),
          * flip vertically (only considered when ``rotation`` is True),
          * swap the H and W axes (only considered when ``rotation`` is True).
        When both options are enabled, the combinations cover rotations by
        0/90/180/270 degrees, with or without a mirror. The same decision is
        applied to every image in ``imgs`` so paired samples stay aligned.

        Flips are done in place (the input arrays are modified). The
        transpose returns a non-contiguous view. Each image must be a 3-D
        (H, W, C) ndarray because of ``transpose(1, 0, 2)``.

        Args:
            imgs (list[ndarray] | ndarray): Images to augment. A bare ndarray
                is wrapped into a one-element list.
            hflip (bool): Allow horizontal flips. Default: True.
            rotation (bool): Allow vertical flips and H/W transposes. Default: True.

        Returns:
            list[ndarray] | ndarray: The augmented images. A single ndarray is
            returned when only one image results.
        """
        hflip = hflip and random.random() < 0.5
        vflip = rotation and random.random() < 0.5
        rot90 = rotation and random.random() < 0.5

        def _augment(img):
            if hflip:  # horizontal, written back into img
                cv2.flip(img, 1, img)
            if vflip:  # vertical, written back into img
                cv2.flip(img, 0, img)
            if rot90:  # swap H and W (view, not a copy)
                img = img.transpose(1, 0, 2)
            return img

        if not isinstance(imgs, list):
            imgs = [imgs]
        imgs = [_augment(img) for img in imgs]
        if len(imgs) == 1:
            imgs = imgs[0]
        return imgs

    @staticmethod
    def _read_image(path):
        """Read ``path`` with ``cv2.imread`` (BGR), raising if it cannot be read.

        ``cv2.imread`` returns None rather than raising on a missing or
        unreadable file. Without this check the failure would surface later
        as an unrelated error inside augmentation or ``ToTensor``.

        Raises:
            FileNotFoundError: If OpenCV could not load the image.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {path}")
        return img

    def __getitem__(self, index):
        """Load, optionally augment, and tensorize the sample at ``index``.

        Returns:
            dict: Keys ``"lr"``, ``"degrade_hr"`` and ``"hr"`` mapped to the
            tensors produced by ``self.transform`` (BGR channel order).

        Raises:
            FileNotFoundError: If any of the three images cannot be read.
        """
        # Read files (OpenCV yields BGR channel order)
        img_lr = self._read_image(self.files_lr[index % len(self.files_lr)])
        img_degrade_hr = self._read_image(self.files_degrade_hr[index % len(self.files_degrade_hr)])
        img_hr = self._read_image(self.files_hr[index % len(self.files_hr)])

        # Augmentation: one shared random transform for all three images
        if random.random() < opt["augment_prob"]:
            img_lr, img_degrade_hr, img_hr = self.augment([img_lr, img_degrade_hr, img_hr])

        # Transform to tensor (ToTensor already scales to [0, 1])
        img_lr = self.transform(img_lr)
        img_degrade_hr = self.transform(img_degrade_hr)
        img_hr = self.transform(img_hr)
        return {"lr": img_lr, "degrade_hr": img_degrade_hr, "hr": img_hr}

    def __len__(self):
        """Return the number of paired samples."""
        # Internal invariant: the lists were validated as equal-length in __init__.
        assert len(self.files_hr) == len(self.files_lr)
        return len(self.files_hr)