# -*- coding: utf-8 -*-
import argparse
import cv2
import numpy as np
import torch
import os, shutil, time
import sys
from multiprocessing import Process, Queue
from os import path as osp
from tqdm import tqdm
import copy
import warnings
import gc
warnings.filterwarnings("ignore")
# import same folder files #
root_path = os.path.abspath('.')
sys.path.append(root_path)
from degradation.ESR.utils import np2tensor
from degradation.ESR.degradations_functionality import *
from degradation.ESR.diffjpeg import *
from degradation.degradation_esr import degradation_v1
from opt import opt
os.environ['CUDA_VISIBLE_DEVICES'] = opt['CUDA_VISIBLE_DEVICES'] #'0,1'
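
# Keys read from the global `opt` dict (defined in opt.py) by this script:
#   'CUDA_VISIBLE_DEVICES', 'parallel_num', 'hr_size', 'scale',
#   'full_patch_source', 'lr_dataset_path', plus 'input_folder' / 'save_folder'
#   which are filled in by main() from the command-line arguments.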


def crop_process(path, crop_size, lr_dataset_path, output_index):
    ''' Crop the image here (also do USM here)
    Args:
        path (str):             Path of the image
        crop_size (int):        Crop size
        lr_dataset_path (str):  LR dataset path folder name
        output_index (int):     The index we use to store images
    Returns:
        output_index (int):     The next index we need to use to store images
    '''
    # Read image
    img = cv2.imread(path)
    height, width = img.shape[0:2]

    res_store = []
    crop_num = (height // crop_size) * (width // crop_size)

    # A shift offset could make the crops cover more of the original image
    shift_offset_h, shift_offset_w = 0, 0

    # Select all sub-frames in order (not randomly selected here)
    choices = [i for i in range(crop_num)]
    shift_offset_h = 0  # random.randint(0, height - crop_size * (height//crop_size))
    shift_offset_w = 0  # random.randint(0, width - crop_size * (width//crop_size))

    for choice in choices:
        row_num = (width // crop_size)
        x, y = crop_size * (choice // row_num), crop_size * (choice % row_num)
        # Add offset
        res_store.append((x, y))

    for (h, w) in res_store:
        cropped_img = img[h + shift_offset_h : h + crop_size + shift_offset_h, w + shift_offset_w : w + crop_size + shift_offset_w, ...]
        cropped_img = np.ascontiguousarray(cropped_img)
        cv2.imwrite(osp.join(lr_dataset_path, f'img_{output_index:06d}.png'), cropped_img, [cv2.IMWRITE_PNG_COMPRESSION, 0])  # Save in lossless mode
        output_index += 1

    return output_index


def single_process(queue, opt, process_id):
    ''' Multi-process worker instance
    Args:
        queue (multiprocessing.Queue):  The input queue
        opt (dict):                     The setting we need to use
        process_id (int):               The id we use to store temporary files
    '''
    # Initialization
    obj_img = degradation_v1()

    while True:
        items = queue.get()
        if items is None:
            break
        input_path, store_path = items

        # Reset kernels in every degradation batch for ESR
        obj_img.reset_kernels(opt)

        # Read the image and transform it to tensor
        img_bgr = cv2.imread(input_path)
        out = np2tensor(img_bgr)  # tensor

        # ESR degradation execution
        obj_img.degradate_process(out, opt, store_path, process_id, verbose=False)


@torch.no_grad()
def generate_low_res_esr(org_opt, verbose=False):
    ''' Generate the LR dataset from the HR one by ESR degradation
    Args:
        org_opt (dict):  The setting we will use
        verbose (bool):  Whether we print out some information
    '''
    # Prepare folders and files
    input_folder = org_opt['input_folder']
    save_folder = org_opt['save_folder']
    if osp.exists(save_folder):
        shutil.rmtree(save_folder)
    if osp.exists("tmp"):
        shutil.rmtree("tmp")
    os.makedirs(save_folder)
    os.makedirs("tmp")
    if os.path.exists("datasets/degradation_log.txt"):
        os.remove("datasets/degradation_log.txt")

    # Scan all images
    input_img_lists, output_img_lists = [], []
    for file in sorted(os.listdir(input_folder)):
        input_img_lists.append(osp.join(input_folder, file))
        output_img_lists.append(osp.join("tmp", file))
    assert len(input_img_lists) == len(output_img_lists)

    # Multi-process preparation
    parallel_num = opt['parallel_num']
    queue = Queue()

    # Save all files in the queue
    for idx in range(len(input_img_lists)):
        # Find the needed img lists
        queue.put((input_img_lists[idx], output_img_lists[idx]))

    # Start the processes
    Processes = []
    for process_id in range(parallel_num):
        p1 = Process(target=single_process, args=(queue, opt, process_id,))
        p1.start()
        Processes.append(p1)
    for _ in range(parallel_num):
        queue.put(None)  # Used to end each process
    # print("All processes start")

    # tqdm wait progress
    for idx in tqdm(range(0, len(output_img_lists)), desc="Degradation"):
        while True:
            if os.path.exists(output_img_lists[idx]):
                break
            time.sleep(0.1)

    # Merge all processes
    for process in Processes:
        process.join()

    # Crop images under folder "tmp"
    output_index = 1
    for img_name in sorted(os.listdir("tmp")):
        path = os.path.join("tmp", img_name)
        output_index = crop_process(path, opt['hr_size'] // opt['scale'], opt['save_folder'], output_index)


def main(args):
    opt['input_folder'] = args.input
    opt['save_folder'] = args.output
    generate_low_res_esr(opt)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default=opt["full_patch_source"], help='Input folder')
    parser.add_argument('--output', type=str, default=opt["lr_dataset_path"], help='Output folder')
    args = parser.parse_args()
    main(args)
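
# Example invocation (script name and dataset paths are illustrative, not part of the file;
# defaults come from opt["full_patch_source"] and opt["lr_dataset_path"]):
#   python generate_lr_esr.py --input datasets/train_hr_patches --output datasets/train_lr_patches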