|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import argparse |
|
from tqdm import tqdm |
|
from PIL import Image |
|
import functools |
|
from multiprocessing import Pool |
|
import math |
|
|
|
|
|
def arg_parser():
    """Build the command-line parser for the crop-extraction script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--crops``, ``--root-dir``,
        ``--output-dir`` and ``--nthread`` (all required) plus ``--imsize``,
        ``--max-subdir-levels`` and ``--ideal-number-pairs-in-dir`` (optional,
        with defaults 256, 5 and 500 respectively).
    """
    # The first positional parameter of ArgumentParser is ``prog``, not the
    # description, so the text must be passed by keyword to appear in --help
    # as a description instead of replacing the program name in the usage line.
    parser = argparse.ArgumentParser(
        description='Generate cropped image pairs from image crop list')

    parser.add_argument('--crops', type=str, required=True, help='crop file')
    parser.add_argument('--root-dir', type=str, required=True, help='root directory')
    parser.add_argument('--output-dir', type=str, required=True, help='output directory')
    parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
    parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
    parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
    parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
    return parser
|
|
|
|
|
def main(args):
    """Generate all crops listed in ``args.crops`` and write a listing file.

    The crop list is loaded, turned into one job per image pair, and each job
    is processed by ``save_image_crops`` (in a process pool when
    ``args.nthread > 1``, sequentially otherwise). The relative path of every
    saved pair is written to ``<output_dir>/listing.txt``.

    Args:
        args: parsed command-line namespace; ``crops``, ``output_dir``,
            ``nthread``, ``ideal_number_pairs_in_dir`` and
            ``max_subdir_levels`` are read here, and the namespace is
            forwarded to ``save_image_crops``.
    """
    listing_path = os.path.join(args.output_dir, 'listing.txt')

    print(f'Loading list of crops ... ({args.nthread} threads)')
    crops, num_crops_to_generate = load_crop_file(args.crops)

    print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
    # math.log() rejects 0 crops and a base of 1, and a single crop yields
    # log == 0, which would make the root exponent below divide by zero.
    # In all of those cases a single flat directory level is used instead.
    if num_crops_to_generate > 1 and args.ideal_number_pairs_in_dir > 1:
        num_levels = math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir))
        num_levels = max(1, min(num_levels, args.max_subdir_levels))
    else:
        num_levels = 1
    num_pairs_in_dir = max(1, math.ceil(num_crops_to_generate ** (1/num_levels)))

    jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
    del crops  # the jobs hold everything needed from here on

    os.makedirs(args.output_dir, exist_ok=True)
    call = functools.partial(save_image_crops, args)
    pool = Pool(args.nthread) if args.nthread > 1 else None

    print(f"Generating cropped images to {args.output_dir} ...")
    try:
        mmap = pool.imap_unordered if pool is not None else map
        with open(listing_path, 'w') as listing:
            listing.write('# pair_path\n')
            for results in tqdm(mmap(call, jobs), total=len(jobs)):
                for path in results:
                    listing.write(f'{path}\n')
    finally:
        # All results have been consumed on success, so terminating is safe;
        # on failure it stops the workers instead of leaving them running.
        if pool is not None:
            pool.terminate()
            pool.join()
    print('Finished writing listing to', listing_path)
|
|
|
|
|
def load_crop_file(path):
    """Parse a crop list file into image pairs and their crop rectangles.

    Fields on each line are separated by ``', '``. Two kinds of lines are
    recognised:

    * a pair header ``img1, img2, rotation`` (fewer than 8 fields), which
      starts a new image pair;
    * a crop line ``l1, r1, t1, b1, l2, r2, t2, b2`` (8 integers), which is
      attached to the most recent pair header.

    Lines starting with ``#`` and blank lines are skipped.

    Args:
        path: path of the crop list file.

    Returns:
        A ``(pairs, num_crops_to_generate)`` tuple. ``pairs`` is a list of
        ``(img1, img2, rotation, crops)`` tuples where ``crops`` is a list of
        ``(rect1, rect2)`` and each rect is reordered to ``(l, t, r, b)``.
        ``num_crops_to_generate`` is the total number of crop lines read.

    Raises:
        ValueError: if a line does not have the expected number of fields, a
            numeric field is not an integer, or a crop line appears before
            any pair header.
    """
    # Close the file deterministically instead of relying on garbage collection.
    with open(path) as f:
        data = f.read().splitlines()

    pairs = []
    num_crops_to_generate = 0
    for line in tqdm(data):
        # Blank lines (e.g. a trailing newline block) carry no data.
        if not line.strip() or line.startswith('#'):
            continue
        fields = line.split(', ')
        if len(fields) < 8:
            img1, img2, rotation = fields
            pairs.append((img1, img2, int(rotation), []))
        else:
            if not pairs:
                raise ValueError(f'crop line found before any image pair: {line!r}')
            l1, r1, t1, b1, l2, r2, t2, b2 = map(int, fields)
            # The file stores left/right/top/bottom; rects are kept as (l, t, r, b).
            rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
            pairs[-1][-1].append((rect1, rect2))
            num_crops_to_generate += 1
    return pairs, num_crops_to_generate
|
|
|
|
|
def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
    """Turn the parsed image pairs into one job per pair with output paths.

    Every crop receives a running global index. Its relative output path is
    made of ``num_levels - 1`` directory components obtained by successively
    dividing the index by decreasing powers of ``num_pairs_in_dir``, followed
    by the full index itself; all components are written in lowercase hex
    without the ``0x`` prefix and joined with ``/``.

    Args:
        pairs: iterable of ``(img1, img2, rotation, crops)`` tuples.
        num_levels: number of path components per crop.
        num_pairs_in_dir: base used to spread crops over directories.

    Returns:
        List of ``((img1, img2), rotation, crops, paths)`` tuples, where
        ``rotation`` is replaced by 0 when it lies within [-60, 60] and
        ``paths`` holds one relative path per crop.
    """
    # Divisor applied at each directory level, largest power first.
    divisors = [num_pairs_in_dir**exp for exp in range(num_levels - 1, -1, -1)]

    def relative_path(crop_idx):
        components = []
        remainder = crop_idx
        for divisor in divisors[:max(num_levels - 1, 0)]:
            quotient, remainder = divmod(remainder, divisor)
            components.append(quotient)
        # The leaf name is the untouched global index, so it is always unique.
        components.append(crop_idx)
        return '/'.join(hex(component)[2:] for component in components)

    jobs = []
    next_idx = 0
    for img1, img2, rotation, crops in tqdm(pairs):
        # Small rotations are treated as no rotation at all.
        if abs(rotation) <= 60:
            rotation = 0
        count = len(crops)
        paths = [relative_path(i) for i in range(next_idx, next_idx + count)]
        next_idx += count
        jobs.append(((img1, img2), rotation, crops, paths))
    return jobs
|
|
|
|
|
def load_image(path):
    """Load the image stored at ``path`` and convert it to RGB.

    Args:
        path: path of the image file to read.

    Returns:
        The image converted to ``'RGB'`` mode.

    Raises:
        OSError: if opening or converting the image fails for any reason.
            A ``skipping`` message is printed first, and the original
            exception is chained as the cause.
    """
    try:
        # The context manager releases the underlying file as soon as the
        # converted copy has been produced.
        with Image.open(path) as img:
            return img.convert('RGB')
    except Exception as e:
        print('skipping', path, e)
        # Normalise every failure to OSError, but keep the path and the cause
        # so the error is still diagnosable.
        raise OSError(f'could not load image {path}: {e}') from e
|
|
|
|
|
def save_image_crops(args, data):
    """Crop, resize, rotate and save all crops of one image pair.

    Args:
        args: parsed command-line namespace; ``root_dir``, ``output_dir`` and
            ``imsize`` are read.
        data: job tuple ``(img_pair, rot, crops, paths)`` where ``img_pair``
            holds the two image paths relative to ``args.root_dir``, ``rot``
            is the rotation in degrees applied to the second crop, ``crops``
            is a list of ``(rect1, rect2)`` and ``paths`` the matching
            relative output paths (without suffix).

    Returns:
        List of the relative paths whose ``_1.jpg`` / ``_2.jpg`` files were
        written; empty if either source image could not be loaded.

    Raises:
        FileExistsError: if one of the output files already exists.
    """
    img_pair, rot, crops, paths = data
    try:
        img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
    except OSError:
        # An unreadable image makes the whole pair unusable; skip it.
        return []

    def area(sz):
        return sz[0] * sz[1]

    tgt_size = (args.imsize, args.imsize)

    # Lossless transpose operation for each non-zero quarter turn.
    quarter_turns = {
        90: Image.Transpose.ROTATE_90,
        180: Image.Transpose.ROTATE_180,
        270: Image.Transpose.ROTATE_270,
    }

    def prepare_crop(img, rect, rot=0):
        # Cut out the requested rectangle.
        img = img.crop(rect)

        # Use LANCZOS when shrinking to less than a quarter of the crop's
        # area, BICUBIC otherwise.
        interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
        img = img.resize(tgt_size, resample=interp)

        # Snap the rotation to the nearest multiple of 90 degrees in [0, 360).
        rot90 = (round(rot/90) % 4) * 90
        if rot90 in quarter_turns:
            img = img.transpose(quarter_turns[rot90])
        return img

    results = []
    for (rect1, rect2), path in zip(crops, paths):
        crop1 = prepare_crop(img1, rect1)
        crop2 = prepare_crop(img2, rect2, rot)

        fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
        fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
        os.makedirs(os.path.dirname(fullpath1), exist_ok=True)

        # Explicit check rather than ``assert`` so that existing crops are
        # still protected when Python runs with -O.
        for fullpath in (fullpath1, fullpath2):
            if os.path.isfile(fullpath):
                raise FileExistsError(f'refusing to overwrite existing crop: {fullpath}')
        crop1.save(fullpath1)
        crop2.save(fullpath2)
        results.append(path)

    return results
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line and run the extraction.
    main(arg_parser().parse_args())
|
|
|
|