File size: 5,494 Bytes
e4bf056
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# 
# --------------------------------------------------------
# Extracting crops for pre-training
# --------------------------------------------------------

import os
import argparse
from tqdm import tqdm
from PIL import Image
import functools
from multiprocessing import Pool
import math


def arg_parser():
    """Build the command-line parser for the crop-extraction script.

    Returns:
        An ``argparse.ArgumentParser`` defining the options read by ``main`` and
        ``save_image_crops`` (crop list, input/output directories, crop size,
        parallelism and output-tree layout).
    """
    # Pass the summary as ``description``: the first positional argument of
    # ArgumentParser is ``prog``, which would replace the program name in the
    # usage line with this sentence.
    parser = argparse.ArgumentParser(description='Generate cropped image pairs from image crop list')

    parser.add_argument('--crops', type=str, required=True, help='crop list file')
    parser.add_argument('--root-dir', type=str, required=True, help='root directory of the input images')
    parser.add_argument('--output-dir', type=str, required=True, help='output directory')
    parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
    parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous worker processes')
    parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum depth of nested output subdirectories')
    parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
    return parser


def main(args):
    """Extract every crop listed in ``args.crops`` and write a listing file.

    Crop pairs are saved under ``args.output_dir`` in a tree of subdirectories.
    The relative path stem of every successfully saved pair is written, one per
    line, to ``<output_dir>/listing.txt`` after a ``# pair_path`` header.

    Args:
        args: parsed CLI namespace produced by ``arg_parser()``.
    """
    listing_path = os.path.join(args.output_dir, 'listing.txt')

    print(f'Loading list of crops ... ({args.nthread} threads)')
    crops, num_crops_to_generate = load_crop_file(args.crops)

    print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
    # Pick the directory-tree depth so that each directory holds about
    # `ideal_number_pairs_in_dir` entries, capped at `max_subdir_levels`.
    # Clamp to at least one level: log(0) is undefined, log(1) == 0, and a zero
    # depth would make `1 / num_levels` below divide by zero.
    if num_crops_to_generate > 1:
        num_levels = math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir))
    else:
        num_levels = 1
    num_levels = max(1, min(num_levels, args.max_subdir_levels))
    # Per-level fan-out chosen so that num_pairs_in_dir ** num_levels roughly
    # covers the number of crops.
    num_pairs_in_dir = math.ceil(max(num_crops_to_generate, 1) ** (1 / num_levels))

    jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
    del crops

    os.makedirs(args.output_dir, exist_ok=True)
    call = functools.partial(save_image_crops, args)

    # Only use a process pool when more than one worker is requested. The pool
    # is torn down in `finally` so worker processes are never leaked, even if
    # an exception aborts the loop.
    pool = Pool(args.nthread) if args.nthread > 1 else None
    try:
        mmap = pool.imap_unordered if pool is not None else map

        print(f"Generating cropped images to {args.output_dir} ...")
        with open(listing_path, 'w') as listing:
            listing.write('# pair_path\n')
            for results in tqdm(mmap(call, jobs), total=len(jobs)):
                for path in results:
                    listing.write(f'{path}\n')
    finally:
        if pool is not None:
            # All results have been consumed on the success path, so every
            # task has finished; terminate() mirrors Pool.__exit__.
            pool.terminate()
            pool.join()
    print('Finished writing listing to', listing_path)


def load_crop_file(path):
    """Parse a crop-list file into image pairs and their crop rectangles.

    Lines are processed in order, with fields separated by ``', '``:
      * lines starting with ``'#'`` and blank lines are ignored;
      * a line with fewer than 8 fields must be ``img1, img2, rotation`` and
        starts a new image pair;
      * a line with 8 integer fields ``l1, r1, t1, b1, l2, r2, t2, b2`` adds one
        pair of crop rectangles to the most recently declared image pair.

    Args:
        path: path of the crop-list file.

    Returns:
        ``(pairs, num_crops_to_generate)``. ``pairs`` is a list of
        ``(img1, img2, rotation, crops)`` tuples. ``crops`` is a list of
        ``((l1, t1, r1, b1), (l2, t2, r2, b2))`` boxes, reordered to the
        (left, top, right, bottom) order used by ``Image.crop``.
        ``num_crops_to_generate`` is the total number of crop rectangle pairs.

    Raises:
        ValueError: if a line has an unexpected number of fields, a non-integer
            value, or if a crop line appears before any image-pair line.
    """
    # Read eagerly and close the file right away.
    with open(path) as f:
        data = f.read().splitlines()

    pairs = []
    num_crops_to_generate = 0
    for lineno, line in enumerate(tqdm(data), start=1):
        if not line.strip() or line.startswith('#'):
            continue
        fields = line.split(', ')
        if len(fields) < 8:
            img1, img2, rotation = fields
            pairs.append((img1, img2, int(rotation), []))
        else:
            if not pairs:
                raise ValueError(f'{path}:{lineno}: crop rectangles found before any image pair line')
            l1, r1, t1, b1, l2, r2, t2, b2 = map(int, fields)
            rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
            # Attach the crop to the crop list of the latest image pair.
            pairs[-1][-1].append((rect1, rect2))
            num_crops_to_generate += 1
    return pairs, num_crops_to_generate


def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
    """Turn parsed image pairs into worker jobs with pre-assigned output paths.

    Every crop receives a running global index, starting at 0 and counted across
    all pairs. Its output path stem is built in lowercase hex. The directory
    components are the leading ``num_levels - 1`` base-``num_pairs_in_dir``
    digits of that index, most significant first. The final component is the
    full global index itself.

    Rotations in [-60, 60] degrees are reset to 0.

    Args:
        pairs: list of ``(img1, img2, rotation, crops)`` tuples.
        num_levels: number of path components (directories + file stem).
        num_pairs_in_dir: fan-out base used for the directory digits.

    Returns:
        A list of ``((img1, img2), rotation, crops, paths)`` jobs, with one path
        stem in ``paths`` per entry of ``crops``.
    """
    # Place value of each path level, most significant first.
    powers = [num_pairs_in_dir ** exponent for exponent in range(num_levels - 1, -1, -1)]

    def get_path(global_idx):
        parts = []
        remainder = global_idx
        # All levels but the last become directory components.
        for power in powers[:-1]:
            parts.append(remainder // power)
            remainder %= power
        # The file stem is the full global index, which keeps it unique.
        parts.append(global_idx)
        return '/'.join(format(part, 'x') for part in parts)

    jobs = []
    next_idx = 0
    for img1, img2, rotation, crops in tqdm(pairs):
        if abs(rotation) <= 60:
            rotation = 0  # most likely not a true rotation
        count = len(crops)
        paths = [get_path(i) for i in range(next_idx, next_idx + count)]
        next_idx += count
        jobs.append(((img1, img2), rotation, crops, paths))
    return jobs


def load_image(path):
    """Open the image at *path* and return an RGB copy of it.

    The source file is opened in a ``with`` block, so its handle is released as
    soon as the converted copy exists. Any failure while opening or decoding is
    printed as ``skipping <path> <error>``. It is then re-raised as ``OSError``,
    chained to the original exception. ``OSError`` is the type that
    ``save_image_crops`` catches to skip a pair.

    Raises:
        OSError: if the image cannot be opened or converted.
    """
    try:
        with Image.open(path) as img:
            return img.convert('RGB')
    except Exception as e:
        print('skipping', path, e)
        raise OSError(f'cannot load image {path!r}: {e}') from e


def save_image_crops(args, data):
    """Cut, resize, rotate and save all crops of one image pair.

    Args:
        args: parsed CLI namespace; reads ``root_dir``, ``output_dir`` and ``imsize``.
        data: one job tuple ``((img1_path, img2_path), rotation, crops, paths)``.
            ``crops`` is a list of ``(rect1, rect2)`` boxes, and ``paths`` holds
            one relative output stem per crop.

    Returns:
        The list of output stems actually written. It is empty if either image
        failed to load. For each stem ``p``, ``p_1.jpg`` and ``p_2.jpg`` are
        written under ``args.output_dir``; only the second image's crop is rotated.

    Raises:
        FileExistsError: if an output file already exists (outputs are never overwritten).
    """
    img_pair, rot, crops, paths = data
    try:
        img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
    except OSError:
        # load_image already printed the failure; skip this pair.
        return []

    def area(sz):
        return sz[0] * sz[1]

    tgt_size = (args.imsize, args.imsize)

    # Transpose for each non-zero right-angle rotation.
    transposes = {
        90: Image.Transpose.ROTATE_90,
        180: Image.Transpose.ROTATE_180,
        270: Image.Transpose.ROTATE_270,
    }

    def prepare_crop(img, rect, rot=0):
        """Crop *img* to *rect*, resize it to ``tgt_size``, and rotate it by
        *rot* rounded to a multiple of 90 degrees (Python ``round``, ties to even)."""
        img = img.crop(rect)

        # LANCZOS for strong downscaling (source area > 4x target), BICUBIC otherwise.
        interp = Image.Resampling.LANCZOS if area(img.size) > 4 * area(tgt_size) else Image.Resampling.BICUBIC
        img = img.resize(tgt_size, resample=interp)

        rot90 = (round(rot / 90) % 4) * 90
        transpose = transposes.get(rot90)
        if transpose is not None:
            img = img.transpose(transpose)
        return img

    results = []
    for (rect1, rect2), path in zip(crops, paths):
        crop1 = prepare_crop(img1, rect1)
        crop2 = prepare_crop(img2, rect2, rot)

        fullpath1 = os.path.join(args.output_dir, path + '_1.jpg')
        fullpath2 = os.path.join(args.output_dir, path + '_2.jpg')
        os.makedirs(os.path.dirname(fullpath1), exist_ok=True)

        # Refuse to overwrite existing outputs. Use an explicit raise rather than
        # `assert` so the check stays active under `python -O`.
        for fullpath in (fullpath1, fullpath2):
            if os.path.isfile(fullpath):
                raise FileExistsError(fullpath)
        crop1.save(fullpath1)
        crop2.save(fullpath2)
        results.append(path)

    return results


if __name__ == '__main__':
    # Script entry point: parse the command line, then run the extraction.
    parsed_args = arg_parser().parse_args()
    main(parsed_args)