import os import random from pathlib import Path from PIL import Image import cv2 import ffmpeg import io import av import numpy as np import torch from torchvision.transforms.functional import normalize from basicsr.data.degradations import (random_add_gaussian_noise, random_mixed_kernels) from basicsr.data.transforms import augment from basicsr.utils import FileClient, get_root_logger, img2tensor, imfrombytes, scandir from basicsr.utils.registry import DATASET_REGISTRY from facelib.utils.face_restoration_helper import FaceAligner from torch.utils import data as data @DATASET_REGISTRY.register() class SingleVFHQDataset(data.Dataset): """Support for blind setting adopted in paper. We excludes the random scale compared to GFPGAN. This dataset is adopted in BasicVSR. The degradation order is blur+downsample+noise Note that we skip the low quality frames within the VFHQ clip. Directly read image by cv2. Generate LR images online. NOTE: The specific degradation order is blur-noise-downsample-crf-upsample The keys are generated from a meta info txt file. Key format: subfolder-name/clip-length/frame-name Key examples: "id00020#t0bbIRgKKzM#00381.txt#000.mp4/00000152/00000000" GT (gt): Ground-Truth; LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. Args: opt (dict): Config for train dataset. It contains the following keys: dataroot_gt (str): Data root path for gt. dataroot_clip_meta_info (srt): Data root path for meta info of each gt clip. global_meta_info_file (str): Path for global meta information file. io_backend (dict): IO backend type and other kwarg. num_frame (int): Window size for input frames. interval_list (list): Interval list for temporal augmentation. random_reverse (bool): Random reverse input frames. use_flip (bool): Use horizontal flips. use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). """ def __init__(self, opt): super(SingleVFHQDataset, self).__init__() self.opt = opt self.gt_root = Path(opt['dataroot_gt']) self.normalize = opt.get('normalize', False) self.need_align = opt.get('need_align', False) logger = get_root_logger() self.keys = [] with open(opt['global_meta_info_file'], 'r') as fin: for line in fin: real_clip_path = '/'.join(line.split('/')[:-1]) clip_length = line.split('/')[-1] clip_length = int(clip_length) self.keys.extend( [f'{real_clip_path}/{clip_length:08d}/{frame_idx:08d}' for frame_idx in range(int(clip_length))]) # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] self.is_lmdb = False if self.io_backend_opt['type'] == 'lmdb': self.is_lmdb = True self.io_backend_opt['db_paths'] = [self.gt_root] self.io_backend_opt['client_keys'] = ['gt'] if self.need_align: self.dataroot_meta_info = opt['dataroot_meta_info'] self.face_aligner = FaceAligner( upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True,) def __getitem__(self, index): if self.file_client is None: self.file_client = FileClient( self.io_backend_opt.pop('type'), **self.io_backend_opt) key = self.keys[index] real_clip_path = '/'.join(key.split('/')[:-2]) clip_length = int(key.split('/')[-2]) frame_idx = int(key.split('/')[-1]) # get the neighboring GT frames flag = real_clip_path.split('/')[0] clip_name = real_clip_path.split('/')[-1] paths = sorted(list(scandir(os.path.join( self.gt_root, clip_name)))) assert len(paths) == clip_length, "Wrong length of frame list" img_gt_path = os.path.join( self.gt_root, clip_name, paths[frame_idx]) img_bytes = self.file_client.get(img_gt_path, 'gt') img_gt = imfrombytes(img_bytes, float32=True) # alignment if self.need_align: clip_info_path = os.path.join( self.dataroot_meta_info, f'{clip_name}.txt') clip_info = [] with open(clip_info_path, 'r', encoding='utf-8') as fin: for line in fin: line = line.strip() if line.startswith('0'): clip_info.append(line) landmarks_str = clip_info[frame_idx].split(' ')[1:] landmarks = np.array([float(x) for x in landmarks_str]).reshape(5, 2) self.face_aligner.clean_all() # align and warp each face img_gt = self.face_aligner.align_single_face(img_gt, landmarks) # augmentation - flip, rotate img_gt = augment(img_gt, self.opt['use_flip'], self.opt['use_rot']) img_in = img_gt # ------------- end --------------# img_in, img_gt = img2tensor([img_in, img_gt]) if self.normalize: normalize(img_in, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True) normalize(img_gt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True) # img_lqs: (t, c, h, w) # img_gts: (t, c, h, w) # key: str return {'in': img_in, 'gt': img_gt, 'key': key} def __len__(self): return len(self.keys) @DATASET_REGISTRY.register() class VFHQDataset(data.Dataset): """Support for blind setting adopted in paper. We excludes the random scale compared to GFPGAN. This dataset is adopted in BasicVSR. The degradation order is blur+downsample+noise Note that we skip the low quality frames within the VFHQ clip. Directly read image by cv2. Generate LR images online. NOTE: The specific degradation order is blur-noise-downsample-crf-upsample The keys are generated from a meta info txt file. Key format: subfolder-name/clip-length/frame-name Key examples: "id00020#t0bbIRgKKzM#00381.txt#000.mp4/00000152/00000000" GT (gt): Ground-Truth; LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. Args: opt (dict): Config for train dataset. It contains the following keys: dataroot_gt (str): Data root path for gt. dataroot_clip_meta_info (srt): Data root path for meta info of each gt clip. global_meta_info_file (str): Path for global meta information file. io_backend (dict): IO backend type and other kwarg. num_frame (int): Window size for input frames. interval_list (list): Interval list for temporal augmentation. random_reverse (bool): Random reverse input frames. use_flip (bool): Use horizontal flips. use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). """ def __init__(self, opt): super(VFHQDataset, self).__init__() self.opt = opt self.gt_root = Path(opt['dataroot_gt']) self.num_frame = opt['num_frame'] self.scale = opt['scale'] self.need_align = opt.get('need_align', False) self.normalize = opt.get('normalize', False) self.keys = [] with open(opt['global_meta_info_file'], 'r') as fin: for line in fin: real_clip_path = '/'.join(line.split('/')[:-1]) clip_length = line.split('/')[-1] clip_length = int(clip_length) self.keys.extend( [f'{real_clip_path}/{clip_length:08d}/{frame_idx:08d}' for frame_idx in range(int(clip_length))]) # file client (io backend) self.file_client = None self.io_backend_opt = opt['io_backend'] self.is_lmdb = False if self.io_backend_opt['type'] == 'lmdb': self.is_lmdb = True self.io_backend_opt['db_paths'] = [self.gt_root] self.io_backend_opt['client_keys'] = ['gt'] # temporal augmentation configs self.interval_list = opt['interval_list'] self.random_reverse = opt['random_reverse'] interval_str = ','.join(str(x) for x in opt['interval_list']) logger = get_root_logger() logger.info(f'Temporal augmentation interval list: [{interval_str}]; ' f'random reverse is {self.random_reverse}.') # degradations # blur self.blur_kernel_size = opt['blur_kernel_size'] self.kernel_list = opt['kernel_list'] self.kernel_prob = opt['kernel_prob'] self.blur_x_sigma = opt['blur_x_sigma'] self.blur_y_sigma = opt['blur_y_sigma'] # noise self.noise_range = opt['noise_range'] # resize self.resize_prob = opt['resize_prob'] # crf self.crf_range = opt['crf_range'] # codec self.vcodec = opt['vcodec'] self.vcodec_prob = opt['vcodec_prob'] logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, ' f'x_sigma: [{", ".join(map(str, self.blur_x_sigma))}], ' f'y_sigma: [{", ".join(map(str, self.blur_y_sigma))}], ') logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') logger.info( f'CRF compression: [{", ".join(map(str, self.crf_range))}]') logger.info(f'Codec: [{", ".join(map(str, self.vcodec))}]') if self.need_align: self.dataroot_meta_info = opt['dataroot_meta_info'] self.face_aligner = FaceAligner( upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True,) def __getitem__(self, index): if self.file_client is None: self.file_client = FileClient( self.io_backend_opt.pop('type'), **self.io_backend_opt) key = self.keys[index] real_clip_path = '/'.join(key.split('/')[:-2]) clip_length = int(key.split('/')[-2]) frame_idx = int(key.split('/')[-1]) clip_name = real_clip_path.split('/')[-1] paths = sorted(list(scandir(os.path.join( self.gt_root, clip_name)))) # determine the neighboring frames interval = random.choice(self.interval_list) # exceed the length, re-select a new clip while (clip_length - self.num_frame * interval) < 0: interval = random.choice(self.interval_list) # ensure not exceeding the borders # print(self.num_frame, type(self.num_frame)) # print(interval, type(interval)) start_frame_idx = frame_idx - self.num_frame // 2 * interval end_frame_idx = frame_idx + self.num_frame // 2 * interval # flag = (start_frame_idx < 0) or (end_frame_idx > clip_length) # print(key, start_frame_idx, end_frame_idx, interval, flag) # each clip has 100+ frames while (start_frame_idx < 0) or (end_frame_idx > clip_length): frame_idx = random.randint(self.num_frame//2 * interval, clip_length - self.num_frame//2 * interval) start_frame_idx = frame_idx - self.num_frame // 2 * interval end_frame_idx = frame_idx + self.num_frame // 2 * interval neighbor_list = list( range(start_frame_idx, end_frame_idx, interval)) # print(start_frame_idx, end_frame_idx, frame_idx, interval) # random reverse if self.random_reverse and random.random() < 0.5: neighbor_list.reverse() assert len(neighbor_list) == self.num_frame, ( f'Wrong length of neighbor list: {len(neighbor_list)}') # get the neighboring GT frames img_gts = [] if self.need_align: clip_info_path = os.path.join( self.dataroot_meta_info, f'{clip_name}.txt') clip_info = [] with open(clip_info_path, 'r', encoding='utf-8') as fin: for line in fin: line = line.strip() if line.startswith('0'): clip_info.append(line) for neighbor in neighbor_list: assert paths[neighbor] == clip_info[neighbor].split(' ')[0], \ f'{clip_name}: Mismatch frame {paths[neighbor]} and {clip_info[neighbor]}' # img_gt_path = os.path.join( # self.gt_root, clip_name, f'{neighbor:08d}.png') img_gt_path = os.path.join( self.gt_root, clip_name, paths[neighbor]) # img_bytes = self.file_client.get(img_gt_path, 'gt') # img_gt = imfrombytes(img_bytes, float32=True) # img_gt = cv2.imread(img_gt_path) / 255.0 img_gt = np.asarray(Image.open(img_gt_path))[:, :, ::-1] / 255.0 img_gts.append(img_gt) # augmentation - flip, rotate img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot']) # ------------- generate LQ frames --------------# # add blur kernel = random_mixed_kernels(self.kernel_list, self.kernel_prob, self.blur_kernel_size, self.blur_x_sigma, self.blur_y_sigma) img_lqs = [cv2.filter2D(v, -1, kernel) for v in img_gts] # add noise img_lqs = [ random_add_gaussian_noise(v, self.noise_range, gray_prob=0.5, clip=True, rounds=False) for v in img_lqs ] # downsample original_height, original_width = img_gts[0].shape[0:2] resize_type = random.choices( [cv2.INTER_AREA, cv2.INTER_LINEAR, cv2.INTER_CUBIC], self.resize_prob)[0] resized_height, resized_width = int( original_height // self.scale), int(original_width // self.scale) # ensure the resized_height and resized_width are even numbers img_lqs = [cv2.resize(v, (resized_width, resized_height), interpolation=resize_type) for v in img_lqs] # add noise img_lqs = [ random_add_gaussian_noise(v, self.noise_range, gray_prob=0.5, clip=True, rounds=False) for v in img_lqs ] # ffmpeg crf = np.random.randint(self.crf_range[0], self.crf_range[1]) codec = random.choices(self.vcodec, self.vcodec_prob)[0] buf = io.BytesIO() with av.open(buf, 'w', 'mp4') as container: stream = container.add_stream(codec, rate=1) stream.height = resized_height stream.width = resized_width stream.pix_fmt = 'yuv420p' stream.options = {'crf': str(crf)} for img_lq in img_lqs: img_lq = np.clip(img_lq * 255, 0, 255).astype(np.uint8) frame = av.VideoFrame.from_ndarray(img_lq, format='rgb24') frame.pict_type = 0 # Changed from 'NONE' to 0 for packet in stream.encode(frame): container.mux(packet) # Flush stream for packet in stream.encode(): container.mux(packet) img_lqs = [] with av.open(buf, 'r', 'mp4') as container: if container.streams.video: for frame in container.decode(**{'video': 0}): img_lqs.append(frame.to_rgb().to_ndarray() / 255.) assert len(img_lqs) == len(img_gts), 'Wrong length' # ------------ Align -------------# if self.need_align: align_lqs, align_gts = [], [] for frame_idx, (img_lq, img_gt) in enumerate(zip(img_lqs, img_gts)): landmarks_str = clip_info[frame_idx].split(' ')[1:] # print(clip_name, paths[neighbor], landmarks_str) landmarks = np.array([float(x) for x in landmarks_str]).reshape(5, 2) self.face_aligner.clean_all() # align and warp each face img_lq, img_gt = self.face_aligner.align_pair_face( img_lq, img_gt, landmarks) align_lqs.append(img_lq) align_gts.append(img_gt) img_lqs, img_gts = align_lqs, align_gts # ------------- end --------------# img_gts = img2tensor(img_gts) img_lqs = img2tensor(img_lqs) img_gts = torch.stack(img_gts, dim=0) img_lqs = torch.stack(img_lqs, dim=0) if self.normalize: normalize(img_lqs, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True) normalize(img_gts, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True) # img_lqs: (t, c, h, w) # img_gts: (t, c, h, w) # key: str return {'lq': img_lqs, 'gt': img_gts, 'key': key} def __len__(self): return len(self.keys)