# Uploaded by hugaagg using huggingface_hub (commit 2ecc7ab, verified).
import cv2
import numpy as np
import random
import torch
from torchvision.transforms.functional import rgb_to_grayscale
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Randomly flip and/or transpose images, with one draw shared by all inputs.

    Three independent draws, each succeeding with probability 0.5, decide
    whether to apply a horizontal flip, a vertical flip and a swap of the
    first two axes. ``hflip`` enables the horizontal flip; ``rotation``
    enables both the vertical flip and the axis swap. The same decision is
    applied to every image (and flow) so that paired inputs stay aligned.

    Flips are written back into the input arrays (``cv2.flip`` is called with
    the source as destination), so the caller's arrays are modified. The axis
    swap uses ``transpose(1, 0, 2)`` and therefore needs 3-D arrays.

    Args:
        imgs (ndarray | list[ndarray]): Image or list of images.
        hflip (bool): Allow the horizontal flip. Default: True.
        rotation (bool): Allow the vertical flip and the axis swap.
            Default: True.
        flows (ndarray | list[ndarray] | None): Optional flow field(s) to
            transform with the same decision. Assumes each flow is
            (H, W, 2) with channel 0 the horizontal and channel 1 the
            vertical displacement -- confirm against the caller.
            Default: None.
        return_status (bool): When True (and ``flows`` is None), also return
            the flags that were drawn. Default: False.

    Returns:
        * ``imgs`` when ``flows`` is None and ``return_status`` is False;
        * ``(imgs, flows)`` when ``flows`` is given;
        * ``(imgs, (hflip, vflip, rot90))`` when ``return_status`` is True
          and ``flows`` is None.

        A single image/flow, or a one-element list, comes back as a bare
        array; longer lists come back as lists.
    """
    # Draw order matters for reproducibility with a seeded ``random``.
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # flip around the vertical axis, in place
            cv2.flip(img, 1, img)
        if vflip:  # flip around the horizontal axis, in place
            cv2.flip(img, 0, img)
        if rot90:  # swap H and W; returns a view
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        # Mirroring an axis also reverses the displacement along that axis.
        if hflip:
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            # Swapping H and W also swaps the two displacement components.
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows

    if return_status:
        return imgs, (hflip, vflip, rot90)
    return imgs
def mod_crop(img, scale):
    """Trim an image so its height and width are multiples of ``scale``.

    Rows and columns are dropped from the bottom and right edges. The work
    is done on a copy, so the input array is left untouched.

    Args:
        img (ndarray): 2-D or 3-D array whose first two axes are height and
            width.
        scale (int): Divisor the output height and width must be multiples of.

    Returns:
        ndarray: The cropped copy.

    Raises:
        ValueError: If ``img`` is neither 2-D nor 3-D.
    """
    cropped = img.copy()
    if cropped.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {cropped.ndim}.')

    height, width = cropped.shape[:2]
    keep_h = height - height % scale
    keep_w = width - width % scale
    return cropped[:keep_h, :keep_w, ...]
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
    """Crop spatially aligned random patches from GT and LQ images.

    (Translated original note: this is the function that was reported as
    missing.)

    One random top-left corner is drawn on the LQ grid and reused, multiplied
    by ``scale``, on the GT grid, so every returned GT patch covers the same
    region as its LQ patch. Sizes are taken from the first element of each
    list; the returned patches are slices of the inputs, not copies.

    NOTE(review): the LQ patch side is ``gt_patch_size // scale``. When
    ``gt_patch_size`` is not a multiple of ``scale`` the GT patch is larger
    than ``scale`` times the LQ patch and may be cut short at the image
    border -- confirm callers always pass a multiple.

    Args:
        img_gts (ndarray | list[ndarray]): GT image(s), 2-D or 3-D with
            height and width as the first two axes.
        img_lqs (ndarray | list[ndarray]): LQ image(s), same layout.
        gt_patch_size (int): Side length of the GT patch.
        scale (int): Size ratio between GT and LQ.
        gt_path (str | None): Optional identifier of the GT image, appended
            to error messages when given. Default: None.

    Returns:
        tuple: ``(img_gts, img_lqs)`` patches. Each side is a bare array when
        it held a single image, otherwise a list.

    Raises:
        ValueError: If the first GT/LQ image is not 2-D or 3-D, if GT is not
            exactly ``scale`` times the LQ size, or if LQ is smaller than the
            LQ patch size.
    """
    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    path_hint = '' if gt_path is None else f' GT path: {gt_path}.'

    # Only the two leading axes are spatial, so grayscale (H, W) arrays are
    # accepted as well as (H, W, C); anything else is rejected explicitly.
    for first in (img_lqs[0], img_gts[0]):
        if first.ndim not in (2, 3):
            raise ValueError(f'Wrong img ndim: {first.ndim}.' + path_hint)
    h_lq, w_lq = img_lqs[0].shape[:2]
    h_gt, w_gt = img_gts[0].shape[:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(
            f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x multiplication of LQ ({h_lq}, {w_lq}).'
            + path_hint)
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(
            f'LQ ({h_lq}, {w_lq}) is smaller than patch size ({lq_patch_size}, {lq_patch_size}).'
            + path_hint)

    # random.randint is inclusive on both ends, so the last valid offset is
    # reachable.
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # Map the LQ corner onto the GT grid.
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]

    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs