# -*- coding: utf-8 -*- import argparse import cv2 import torch import os, shutil, time import sys from multiprocessing import Process, Queue from os import path as osp from tqdm import tqdm import copy import warnings import gc warnings.filterwarnings("ignore") # import same folder files # root_path = os.path.abspath('.') sys.path.append(root_path) from degradation.ESR.utils import np2tensor from degradation.ESR.degradations_functionality import * from degradation.ESR.diffjpeg import * from degradation.degradation_esr import degradation_v1 from opt import opt os.environ['CUDA_VISIBLE_DEVICES'] = opt['CUDA_VISIBLE_DEVICES'] #'0,1' def crop_process(path, crop_size, lr_dataset_path, output_index): ''' crop the image here (also do usm here) Args: path (str): Path of the image crop_size (int): Crop size lr_dataset_path (str): LR dataset path folder name output_index (int): The index we used to store images Returns: output_index (int): The next index we need to use to store images ''' # read image img = cv2.imread(path) height, width = img.shape[0:2] res_store = [] crop_num = (height//crop_size)*(width//crop_size) # Use shift offset to make image more cover origional image size shift_offset_h, shift_offset_w = 0, 0 # Select all sub-frames order by order (not randomly select here) choices = [i for i in range(crop_num)] shift_offset_h = 0 #random.randint(0, height - crop_size * (height//crop_size)) shift_offset_w = 0 #random.randint(0, width - crop_size * (width//crop_size)) for choice in choices: row_num = (width//crop_size) x, y = crop_size * (choice // row_num), crop_size * (choice % row_num) # add offset res_store.append((x, y)) for (h, w) in res_store: cropped_img = img[h+shift_offset_h : h+crop_size+shift_offset_h, w+shift_offset_w : w+crop_size+shift_offset_w, ...] cropped_img = np.ascontiguousarray(cropped_img) cv2.imwrite(osp.join(lr_dataset_path, f'img_{output_index:06d}.png'), cropped_img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) # Save in lossless mode output_index += 1 return output_index def single_process(queue, opt, process_id): ''' Multi Process instance Args: queue (multiprocessing.Queue): The input queue opt (dict): The setting we need to use process_id (int): The id we used to store temporary file ''' # Initialization obj_img = degradation_v1() while True: items = queue.get() if items == None: break input_path, store_path = items # Reset kernels in every degradation batch for ESR obj_img.reset_kernels(opt) # Read all images and transform them to tensor img_bgr = cv2.imread(input_path) out = np2tensor(img_bgr) # tensor # ESR Degradation execution obj_img.degradate_process(out, opt, store_path, process_id, verbose = False) @torch.no_grad() def generate_low_res_esr(org_opt, verbose=False): ''' Generate LR dataset from HR ones by ESR degradation Args: org_opt (dict): The setting we will use verbose (bool): Whether we print out some information ''' # Prepare folders and files input_folder = org_opt['input_folder'] save_folder = org_opt['save_folder'] if osp.exists(save_folder): shutil.rmtree(save_folder) if osp.exists("tmp"): shutil.rmtree("tmp") os.makedirs(save_folder) os.makedirs("tmp") if os.path.exists("datasets/degradation_log.txt"): os.remove("datasets/degradation_log.txt") # Scan all images input_img_lists, output_img_lists = [], [] for file in sorted(os.listdir(input_folder)): input_img_lists.append(osp.join(input_folder, file)) output_img_lists.append(osp.join("tmp", file)) assert(len(input_img_lists) == len(output_img_lists)) # Multi-Process Preparation parallel_num = opt['parallel_num'] queue = Queue() # Save all files in the Queue for idx in range(len(input_img_lists)): # Find the needed img lists queue.put((input_img_lists[idx], output_img_lists[idx])) # Start the process Processes = [] for process_id in range(parallel_num): p1 = Process(target=single_process, args =(queue, opt, process_id, )) p1.start() Processes.append(p1) for _ in range(parallel_num): queue.put(None) # Used to end the process # print("All Process starts") # tqdm wait progress for idx in tqdm(range(0, len(output_img_lists)), desc ="Degradation"): while True: if os.path.exists(output_img_lists[idx]): break time.sleep(0.1) # Merge all processes for process in Processes: process.join() # Crop images under folder "tmp" output_index = 1 for img_name in sorted(os.listdir("tmp")): path = os.path.join("tmp", img_name) output_index = crop_process(path, opt['hr_size']//opt['scale'], opt['save_folder'], output_index) def main(args): opt['input_folder'] = args.input opt['save_folder'] = args.output generate_low_res_esr(opt) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--input', type=str, default = opt["full_patch_source"], help='Input folder') parser.add_argument('--output', type=str, default = opt["lr_dataset_path"], help='Output folder') args = parser.parse_args() main(args)