# Bread / exposure_augment.py
# Commit e538b68 by huqiming513 ("Upload 14 files"), 1.76 kB.
# (Hosting-page view links from the original listing: raw | history | blame.)
import math
import os
import PIL.Image as Image
import numpy as np
import torch
import torchvision.transforms as vtrans
import tqdm
def _exposure_gains(im_raw, max_overex_rate, steps, num_gen):
    """Pick the brightening gains used to synthesize exposures of one image.

    Args:
        im_raw: image tensor as produced by ``ToTensor``; presumably
            (C, H, W) with values in [0, 1] -- TODO confirm for palette/alpha
            images.
        max_overex_rate: fraction of the brightest pixels whose saturating
            gains are sampled.
        steps: roughly how many gains are sampled from that fraction.
        num_gen: number of gains per regime (``2 * num_gen`` in total).

    Returns:
        ``(gains, min_gain)``: a 1-D numpy array of gains and the smallest
        positive jump (a Python float) among the ``num_gen`` largest jumps
        between neighbouring sampled gains.

    Raises:
        ValueError: if no positive jump exists (e.g. a flat image), so no
            gains can be derived.
    """
    # Per-pixel maximum over dim 0 (channels), flattened to 1-D.
    im_max = torch.flatten(torch.max(im_raw, dim=0, keepdim=True).values)
    # topk returns values in descending order, so 1 / values -- the gain that
    # brings each bright pixel exactly to 1.0 -- is non-decreasing.
    # NOTE(review): pixels equal to 0 give infinite gains; very dark images
    # may need a floor on im_max -- confirm against the dataset.
    k = min(len(im_max), math.floor(len(im_max) * max_overex_rate + 1))
    mag = 1. / torch.topk(im_max, k).values
    # Subsample about ``steps`` gains. The stride is clamped to >= 1 so tiny
    # images do not produce a zero step.
    stride = max(1, int(len(mag) * (1. / steps)))
    mag = mag[::stride]
    # mag_diff[i] = mag[i + 1] - mag[i]; drop the last gain so both align.
    mag_diff = torch.diff(mag, 1)
    mag = mag[:-1]
    # The num_gen largest jumps (fewer if not enough samples exist).
    top_mag_diff = torch.topk(mag_diff, min(num_gen, len(mag_diff))).values
    positive = top_mag_diff[top_mag_diff > 0]
    if len(positive) == 0:
        raise ValueError('brightness profile too flat to derive exposure gains')
    min_gain_t = positive[-1]
    # Last gain sitting just before a jump larger than min_gain. If no jump is
    # strictly larger (only one positive jump among the top ones), fall back
    # to >= so the selection is never empty; min_gain is itself in mag_diff.
    mask = mag_diff > min_gain_t
    if not bool(mask.any()):
        mask = mag_diff >= min_gain_t
    min_gain = min_gain_t.item()
    min_mag = mag[0].item()
    max_mag = mag[mask][-1].item()

    # Low-gain sweep from 1 towards min_mag, then a sweep min_mag..max_mag.
    ma = np.arange(1, min_mag - min_gain, min_gain * 2)
    if len(ma) > num_gen:
        gains = np.append(np.linspace(1, min_mag - min_gain, num_gen),
                          np.linspace(min_mag, max_mag, num_gen))
    elif len(ma) == num_gen:
        gains = np.append(ma, np.linspace(min_mag, max_mag, num_gen))
    else:
        gains = np.linspace(1, max_mag, num_gen * 2)
    return gains, min_gain


def main(fip, fod, max_overex_rate=0.25, steps=20, num_gen=4, progress=None):
    """Write brightened copies of one image into ``fod``.

    The image is multiplied by each gain from ``_exposure_gains``, clipped at
    1.0 and saved as ``<stem>_<i><ext>`` in ``fod``. Files that already exist
    are left untouched, so an interrupted run can be resumed.

    Args:
        fip: path of the input image.
        fod: existing output directory.
        max_overex_rate: fraction of the brightest pixels used to pick gains.
        steps: roughly how many gains are sampled from that fraction.
        num_gen: gains per regime; ``2 * num_gen`` images are produced.
        progress: optional object with ``set_description`` (e.g. a tqdm bar).
            When omitted, the module-level ``bar`` created by the script entry
            point is used if it exists.

    Raises:
        ValueError: if no exposure gains can be derived for the image.
    """
    # Decode once; the context manager closes the file even on errors.
    with Image.open(fip) as img:
        im_raw = vtrans.ToTensor()(img)

    # splitext handles names with several dots (or none); ext keeps its dot.
    fn, ext = os.path.splitext(os.path.basename(fip))

    gains, min_gain = _exposure_gains(im_raw, max_overex_rate, steps, num_gen)

    if progress is None:
        # Backward compatibility: the script entry point defines a global bar.
        progress = globals().get('bar')
    if progress is not None:
        progress.set_description(f'{fn}: {min_gain}')

    for i, gain in enumerate(gains):
        fop = os.path.join(fod, f'{fn}_{i}{ext}')
        if os.path.exists(fop):
            continue
        # Tensor on the left keeps the result a float32 tensor.
        im = (im_raw * float(gain)).clamp_max_(1.)
        vtrans.ToPILImage()(im).save(fop)
if __name__ == '__main__':
    # The LOL dataset is not bundled with this script: download it separately
    # and unpack it so its training images are in ``fid``.
    fid = './data/LOL/train/images'
    fod = './data/LOL/train/images_aug'
    os.makedirs(fod, exist_ok=True)
    # Sorted for a reproducible processing order. ``bar`` must stay a
    # module-level name because ``main`` updates its description.
    bar = tqdm.tqdm(sorted(os.listdir(fid)))
    for fn in bar:
        fip = os.path.join(fid, fn)
        # Skip sub-directories and other non-regular files; they are not images.
        if not os.path.isfile(fip):
            continue
        main(fip, fod)