Spaces:
Running
on
T4
Running
on
T4
# -*- coding: utf-8 -*- | |
import argparse | |
import cv2 | |
import torch | |
import os, shutil, time | |
import sys | |
from multiprocessing import Process, Queue | |
from os import path as osp | |
from tqdm import tqdm | |
import copy | |
import warnings | |
import gc | |
warnings.filterwarnings("ignore") | |
# import same folder files # | |
root_path = os.path.abspath('.') | |
sys.path.append(root_path) | |
from degradation.ESR.utils import np2tensor | |
from degradation.ESR.degradations_functionality import * | |
from degradation.ESR.diffjpeg import * | |
from degradation.degradation_esr import degradation_v1 | |
from opt import opt | |
os.environ['CUDA_VISIBLE_DEVICES'] = opt['CUDA_VISIBLE_DEVICES'] #'0,1' | |
def crop_process(path, crop_size, lr_dataset_path, output_index):
    ''' Cut one image into non-overlapping square tiles and save every tile as a PNG
    Args:
        path (str):             Path of the image
        crop_size (int):        Side length (in pixels) of each square tile
        lr_dataset_path (str):  Folder the tiles are written into
        output_index (int):     Index used to name the first tile (img_{index:06d}.png)
    Returns:
        output_index (int):     The next index we need to use to store images
                                (the input index plus the number of tiles written)
    Raises:
        ValueError: If crop_size is not a positive number
        OSError:    If the image cannot be read
    '''
    # Imported locally so this function does not rely on `np` leaking in through a wildcard import
    import numpy as np

    if crop_size <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")

    # cv2.imread reports failure by returning None instead of raising
    img = cv2.imread(path)
    if img is None:
        raise OSError(f"Failed to read image: {path}")

    height, width = img.shape[0:2]
    rows, cols = height // crop_size, width // crop_size

    # Select all sub-frames order by order (row-major, starting at the top-left corner).
    # Pixels that do not fill a whole tile on the right/bottom edge are dropped.
    for row in range(rows):
        top = row * crop_size
        for col in range(cols):
            left = col * crop_size
            cropped_img = np.ascontiguousarray(img[top : top + crop_size, left : left + crop_size, ...])

            tile_path = osp.join(lr_dataset_path, f'img_{output_index:06d}.png')
            cv2.imwrite(tile_path, cropped_img, [cv2.IMWRITE_PNG_COMPRESSION, 0])     # Save in lossless mode
            output_index += 1

    return output_index
def single_process(queue, opt, process_id):
    ''' Multi Process instance: degrade images taken from the queue until a None sentinel arrives
    Args:
        queue (multiprocessing.Queue):  The input queue; each item is an (input_path, store_path)
                                        tuple, and None tells this worker to stop
        opt (dict):                     The setting we need to use
        process_id (int):               The id passed on to degradate_process
                                        (per the original note: used to store temporary file)
    Raises:
        OSError: If an input image cannot be read
    '''
    # Initialization: one degradation object is reused for every image this worker handles
    degrader = degradation_v1()

    while True:
        items = queue.get()
        if items is None:
            # Sentinel: no more work for this worker
            break
        input_path, store_path = items

        # Reset kernels in every degradation batch for ESR
        degrader.reset_kernels(opt)

        # cv2.imread reports failure by returning None instead of raising,
        # so fail here with the offending path rather than deeper in the pipeline
        img_bgr = cv2.imread(input_path)
        if img_bgr is None:
            raise OSError(f"Failed to read image: {input_path}")
        out = np2tensor(img_bgr)    # tensor

        # ESR Degradation execution
        degrader.degradate_process(out, opt, store_path, process_id, verbose = False)
def generate_low_res_esr(org_opt, verbose=False):
    ''' Generate LR dataset from HR ones by ESR degradation
    Args:
        org_opt (dict):     The setting we will use. Keys read here: 'input_folder', 'save_folder',
                            'parallel_num', 'hr_size' and 'scale'; the whole dict is also handed
                            to every worker process
        verbose (bool):     Whether we print out some information (currently unused)
    Raises:
        RuntimeError: If every worker has exited but an expected degraded image was never written
    '''
    # Prepare folders and files: start from an empty output folder and an empty "tmp" scratch folder
    input_folder = org_opt['input_folder']
    save_folder = org_opt['save_folder']
    if osp.exists(save_folder):
        shutil.rmtree(save_folder)
    if osp.exists("tmp"):
        shutil.rmtree("tmp")
    os.makedirs(save_folder)
    os.makedirs("tmp")
    if os.path.exists("datasets/degradation_log.txt"):
        os.remove("datasets/degradation_log.txt")

    # Scan all images: every input file gets a same-named degraded file under "tmp"
    file_names = sorted(os.listdir(input_folder))
    input_img_lists = [osp.join(input_folder, name) for name in file_names]
    output_img_lists = [osp.join("tmp", name) for name in file_names]

    # Multi-Process Preparation
    parallel_num = org_opt['parallel_num']
    queue = Queue()

    # Save all files in the Queue
    for job in zip(input_img_lists, output_img_lists):
        queue.put(job)

    # Start the process
    workers = []
    for process_id in range(parallel_num):
        worker = Process(target=single_process, args=(queue, org_opt, process_id))
        worker.start()
        workers.append(worker)
    for _ in range(parallel_num):
        queue.put(None)     # Used to end the process

    # tqdm wait progress: poll until each expected output file shows up
    for out_path in tqdm(output_img_lists, desc="Degradation"):
        while not os.path.exists(out_path):
            if not any(worker.is_alive() for worker in workers):
                # A worker may have written the file right before exiting, so look once more
                if os.path.exists(out_path):
                    break
                # Without this check a crashed worker would leave us polling forever
                raise RuntimeError(
                    f"All degradation workers have exited but '{out_path}' was never produced"
                )
            time.sleep(0.1)

    # Merge all processes
    for worker in workers:
        worker.join()

    # Crop images under folder "tmp"; tile numbering starts at 1 and continues across images
    crop_size = org_opt['hr_size'] // org_opt['scale']
    output_index = 1
    for img_name in sorted(os.listdir("tmp")):
        output_index = crop_process(os.path.join("tmp", img_name), crop_size, save_folder, output_index)
def main(args):
    ''' Store the chosen folders in the shared settings and run the LR generation
    Args:
        args: Object with `input` and `output` attributes (the parsed command-line
              arguments when this file is run as a script)
    '''
    # The generation routine reads both folders from the settings dict
    opt['input_folder'], opt['save_folder'] = args.input, args.output
    generate_low_res_esr(opt)
if __name__ == '__main__':
    # Command-line entry point: both folders default to the paths configured in opt
    cli = argparse.ArgumentParser()
    cli.add_argument('--input', type=str, default = opt["full_patch_source"], help='Input folder')
    cli.add_argument('--output', type=str, default = opt["lr_dataset_path"], help='Output folder')
    main(cli.parse_args())