File size: 5,494 Bytes
f53b39e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Extracting crops for pre-training
# --------------------------------------------------------
import os
import argparse
from tqdm import tqdm
from PIL import Image
import functools
from multiprocessing import Pool
import math
def arg_parser():
parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
parser.add_argument('--crops', type=str, required=True, help='crop file')
parser.add_argument('--root-dir', type=str, required=True, help='root directory')
parser.add_argument('--output-dir', type=str, required=True, help='output directory')
parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
return parser
def main(args):
listing_path = os.path.join(args.output_dir, 'listing.txt')
print(f'Loading list of crops ... ({args.nthread} threads)')
crops, num_crops_to_generate = load_crop_file(args.crops)
print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
del crops
os.makedirs(args.output_dir, exist_ok=True)
mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
call = functools.partial(save_image_crops, args)
print(f"Generating cropped images to {args.output_dir} ...")
with open(listing_path, 'w') as listing:
listing.write('# pair_path\n')
for results in tqdm(mmap(call, jobs), total=len(jobs)):
for path in results:
listing.write(f'{path}\n')
print('Finished writing listing to', listing_path)
def load_crop_file(path):
data = open(path).read().splitlines()
pairs = []
num_crops_to_generate = 0
for line in tqdm(data):
if line.startswith('#'):
continue
line = line.split(', ')
if len(line) < 8:
img1, img2, rotation = line
pairs.append((img1, img2, int(rotation), []))
else:
l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
pairs[-1][-1].append((rect1, rect2))
num_crops_to_generate += 1
return pairs, num_crops_to_generate
def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
jobs = []
powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
def get_path(idx):
idx_array = []
d = idx
for level in range(num_levels - 1):
idx_array.append(idx // powers[level])
idx = idx % powers[level]
idx_array.append(d)
return '/'.join(map(lambda x: hex(x)[2:], idx_array))
idx = 0
for pair_data in tqdm(pairs):
img1, img2, rotation, crops = pair_data
if -60 <= rotation and rotation <= 60:
rotation = 0 # most likely not a true rotation
paths = [get_path(idx + k) for k in range(len(crops))]
idx += len(crops)
jobs.append(((img1, img2), rotation, crops, paths))
return jobs
def load_image(path):
try:
return Image.open(path).convert('RGB')
except Exception as e:
print('skipping', path, e)
raise OSError()
def save_image_crops(args, data):
# load images
img_pair, rot, crops, paths = data
try:
img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
except OSError as e:
return []
def area(sz):
return sz[0] * sz[1]
tgt_size = (args.imsize, args.imsize)
def prepare_crop(img, rect, rot=0):
# actual crop
img = img.crop(rect)
# resize to desired size
interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
img = img.resize(tgt_size, resample=interp)
# rotate the image
rot90 = (round(rot/90) % 4) * 90
if rot90 == 90:
img = img.transpose(Image.Transpose.ROTATE_90)
elif rot90 == 180:
img = img.transpose(Image.Transpose.ROTATE_180)
elif rot90 == 270:
img = img.transpose(Image.Transpose.ROTATE_270)
return img
results = []
for (rect1, rect2), path in zip(crops, paths):
crop1 = prepare_crop(img1, rect1)
crop2 = prepare_crop(img2, rect2, rot)
fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
assert not os.path.isfile(fullpath1), fullpath1
assert not os.path.isfile(fullpath2), fullpath2
crop1.save(fullpath1)
crop2.save(fullpath2)
results.append(path)
return results
if __name__ == '__main__':
args = arg_parser().parse_args()
main(args)
|