|
import math |
|
import os |
|
|
|
import PIL.Image as Image |
|
import numpy as np |
|
import torch |
|
import torchvision.transforms as vtrans |
|
import tqdm |
|
|
|
|
|
def main(fip, fod, progress=None):
    """Write brightened copies of one image into ``fod``.

    The gains are chosen from the image itself.  For each pixel the largest
    channel value is taken, and the reciprocal of that value is the gain at
    which the pixel reaches 1.0.  The brightest ``max_overex_rate`` share of
    pixels yields a list of candidate gains, which is subsampled to roughly
    ``steps`` entries.  The largest gaps between neighbouring candidates then
    bound the range from which ``2 * num_gen`` gains are drawn.  Each gain is
    applied to the image, the result is clamped to 1.0 and saved.

    Args:
        fip: Path of the input image.
        fod: Directory receiving the outputs, named ``<stem>_<index><ext>``.
            Outputs that already exist are left untouched.
        progress: Optional object exposing ``set_description`` (for example a
            tqdm bar).  When omitted, a module-level ``bar`` is used if one
            exists, which keeps the script entry point working unchanged.

    Raises:
        IndexError: If no positive gap between candidate gains exists, or no
            gap exceeds the selected threshold (e.g. a perfectly flat image).
    """
    max_overex_rate = 0.25
    steps = 20
    num_gen = 4

    # Decode once and reuse the tensor for both the analysis and the output;
    # the context manager releases the file handle promptly.
    with Image.open(fip) as src:
        im_raw = vtrans.ToTensor()(src)

    # Per-pixel maximum over dim 0 (the channel axis of ToTensor's output).
    im_max = torch.flatten(torch.max(im_raw, dim=0, keepdim=True).values)

    # torch.topk returns values in descending order, so the reciprocals are
    # ascending: mag[0] is the gain at which the brightest pixel reaches 1.0.
    num_top = math.floor(len(im_max) * max_overex_rate + 1)
    mag = 1. / torch.topk(im_max, num_top).values

    # Subsample to about `steps` candidates.  The stride is clamped to 1 so
    # very small images do not produce a zero step.
    stride = max(1, int(len(mag) * (1. / steps)))
    mag = mag[::stride]
    mag_diff = torch.diff(mag, 1)
    mag = mag[:-1]  # align mag[i] with the gap mag_diff[i] that follows it

    # Smallest positive gap among the `num_gen` largest gaps; k is clamped so
    # a short candidate list does not make topk fail.
    top_mag_diff = torch.topk(mag_diff, min(num_gen, len(mag_diff))).values
    min_gain = top_mag_diff[top_mag_diff > 0][-1]
    min_mag = mag[0]
    # Last candidate whose following gap is strictly larger than min_gain.
    max_mag = mag[mag_diff > min_gain][-1]

    # splitext copes with dotted stems ("a.b.png") and keeps the leading dot
    # in `ext`, so the output name is built without adding another one.
    fn, ext = os.path.splitext(os.path.basename(fip))

    if progress is None:
        progress = globals().get('bar')
    if progress is not None:
        progress.set_description(f'{fn}: {min_gain}')

    # Gains below min_mag: stepping by 2 * min_gain from 1 decides whether
    # that range gets its own `num_gen` samples.
    ma = np.arange(1, min_mag - min_gain, min_gain * 2)
    if len(ma) > num_gen:
        mags = np.append(np.linspace(1, min_mag - min_gain, num_gen),
                         np.linspace(min_mag, max_mag, num_gen))
    elif len(ma) == num_gen:
        mags = np.append(ma, np.linspace(min_mag, max_mag, num_gen))
    else:
        mags = np.linspace(1, max_mag, num_gen * 2)

    to_pil = vtrans.ToPILImage()
    for i, gain in enumerate(mags):
        fop = os.path.join(fod, f'{fn}_{i}{ext}')
        if os.path.exists(fop):
            continue  # never overwrite an earlier result

        scaled = im_raw * gain
        scaled.clamp_max_(1.)
        to_pil(scaled).save(fop)
|
|
|
|
|
if __name__ == '__main__':
    # Run the augmentation over every file in `fid`, writing into `fod`.
    fid = './data/LOL/train/images'
    fod = './data/LOL/train/images_aug'
    os.makedirs(fod, exist_ok=True)

    # os.listdir also returns sub-directories, which cannot be opened as
    # images; keep regular files only and sort for a reproducible order.
    names = sorted(
        name for name in os.listdir(fid)
        if os.path.isfile(os.path.join(fid, name))
    )

    # `bar` stays a module-level name because main() reads it to update the
    # progress description.
    bar = tqdm.tqdm(names)
    for fn in bar:
        fip = os.path.join(fid, fn)
        main(fip, fod)
|
|