| import cv2 | |
| import numpy as np | |
| import random | |
| import torch | |
| from torchvision.transforms.functional import rgb_to_grayscale | |
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Randomly flip and/or transpose one image or a list of images.

    Three independent 50% decisions are drawn once per call and the same
    decisions are applied to every image, so paired images stay aligned:

    * horizontal flip -- only considered when ``hflip`` is true;
    * vertical flip -- only considered when ``rotation`` is true;
    * swap of the height and width axes -- only considered when
      ``rotation`` is true.

    Args:
        imgs: A single image array or a list of image arrays. Both (h, w)
            and (h, w, c) layouts are handled.
        hflip: Enable the random horizontal flip.
        rotation: Enable the random vertical flip and the random
            height/width swap.
        flows: Accepted for signature compatibility; not used by this
            implementation.
        return_status: Accepted for signature compatibility; not used by
            this implementation.

    Returns:
        The augmented image when exactly one image was processed (a
        one-element list is unwrapped too), otherwise a list of augmented
        images.

    Note:
        Flips are written back into the input arrays (``cv2.flip`` is given
        the source as its destination), so the caller's arrays are modified.
    """
    # Draw order matters for reproducibility under a fixed seed; a disabled
    # option short-circuits and consumes no random number.
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:
            cv2.flip(img, 1, img)  # in-place horizontal flip
        if vflip:
            cv2.flip(img, 0, img)  # in-place vertical flip
        if rot90:
            # Swap only the two spatial axes so that 2-D (grayscale) images
            # work as well as (h, w, c) ones; for 3-D input this is the same
            # view as transpose(1, 0, 2).
            img = img.swapaxes(0, 1)
        return img

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]
    return imgs
def mod_crop(img, scale):
    """Trim an image so its height and width are multiples of ``scale``.

    Intended for use at test time, per the original note. The input is
    copied first, so the returned array never shares memory with ``img``.

    Args:
        img: Image array with 2 or 3 dimensions; the first two axes are
            treated as height and width.
        scale: Integer the output height and width must be divisible by.

    Returns:
        The copied image with the trailing remainder rows and columns
        removed.

    Raises:
        ValueError: If ``img`` is neither 2-D nor 3-D.
    """
    cropped = img.copy()
    if cropped.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {cropped.ndim}.')

    height, width = cropped.shape[:2]
    # Largest extents not exceeding the original that divide evenly by scale.
    keep_h = height - height % scale
    keep_w = width - width % scale
    return cropped[:keep_h, :keep_w, ...]
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Crop spatially aligned random patches from GT and LQ images.

    (Original note: this is the function that an error reported as missing.)

    A random window of ``gt_patch_size // scale`` pixels per side is chosen
    on the LQ grid; the matching GT window starts at the LQ offset multiplied
    by ``scale`` and is ``gt_patch_size`` pixels per side. The same window is
    applied to every image in each list. Sizes are read from the first image
    of each list only.

    Args:
        img_gts: A ground-truth image array or a list of them, laid out as
            (h, w) or (h, w, c).
        img_lqs: A low-quality image array or a list of them, same layout.
        gt_patch_size: Side length of the GT patch.
        scale: Integer size ratio between GT and LQ.
        gt_path: Accepted for signature compatibility; not used by this
            implementation.

    Returns:
        A tuple ``(img_gts, img_lqs)`` of cropped patches. Each element is a
        single array when its list held exactly one image, otherwise a list.

    Raises:
        ValueError: If the GT size is not exactly ``scale`` times the LQ
            size, or if the LQ image is smaller than the LQ patch.
    """
    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # Read only the two leading axes so that 2-D (grayscale) images are
    # accepted as well as (h, w, c) ones; the slicing below already copes
    # with either layout thanks to the trailing Ellipsis.
    h_lq, w_lq = img_lqs[0].shape[:2]
    h_gt, w_gt = img_gts[0].shape[:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ({lq_patch_size}, {lq_patch_size}).')

    # randint is inclusive at both ends, so the window can touch the border.
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)
    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # Map the LQ offset onto the GT grid.
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]

    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs