"Open

In [None]:
!git clone https://github.com/soumik12345/enhance-me
!pip install -qqq wandb streamlit

In [None]:
import os
import sys

# Make the `enhance_me` package importable without installing it.
# - The setup cell clones the repository into ./enhance-me (relative to the
#   working directory), so the package lives at ./enhance-me/enhance_me.
# - The original code appended "..", presumably for running this notebook from
#   a sub-directory of a local checkout, where the package sits one level up.
# Absolute paths keep the entries valid if the working directory changes later.
# The membership check avoids piling up duplicate entries when the cell is re-run.
for _candidate_root in ("enhance-me", ".."):
    _candidate_abs = os.path.abspath(_candidate_root)
    if (
        os.path.isdir(os.path.join(_candidate_abs, "enhance_me"))
        and _candidate_abs not in sys.path
    ):
        sys.path.append(_candidate_abs)

from PIL import Image
from enhance_me import commons
from enhance_me.mirnet import MIRNet
from enhance_me.zero_dce import ZeroDCE

In [None]:
# @title MIRNet Train Configs
# Colab form: each `# @param` annotation exposes the assignment on its line as
# an editable form widget. These values feed the MIRNet cells below.

# NOTE(review): the name says 256 but image_size below is 128 — confirm which is intended.
experiment_name = "lol_dataset_256" # @param {type:"string"}
image_size = 128 # @param {type:"integer"}
dataset_label = "lol" # @param ["lol"]
apply_random_horizontal_flip = True # @param {type:"boolean"}
apply_random_vertical_flip = True # @param {type:"boolean"}
apply_random_rotation = True # @param {type:"boolean"}
use_mixed_precision = True # @param {type:"boolean"}
# An empty value is turned into None by the cell that constructs MIRNet.
# Avoid saving or sharing the notebook with a real key typed in here: the
# value becomes part of the cell source.
wandb_api_key = "" # @param {type:"string"}
# NOTE(review): the slider allows 1.0, which would presumably leave no training data — confirm.
val_split = 0.1 # @param {type:"slider", min:0.1, max:1.0, step:0.1}
batch_size = 4 # @param {type:"integer"}
num_recursive_residual_groups = 3 # @param {type:"slider", min:1, max:5, step:1}
num_multi_scale_residual_blocks = 2 # @param {type:"slider", min:1, max:5, step:1}
learning_rate = 1e-4 # @param {type:"number"}
# Passed to build_model alongside learning_rate; presumably an optimizer epsilon — TODO confirm.
epsilon = 1e-3 # @param {type:"number"}
epochs = 50 # @param {type:"slider", min:10, max:100, step:5}

In [None]:
# Create the MIRNet wrapper for this experiment. The form field is always a
# string, so an empty key collapses to None and a non-empty key is passed through.
mirnet = MIRNet(
    wandb_api_key=wandb_api_key or None,
    experiment_name=experiment_name,
)

In [None]:
# Build the MIRNet datasets from the form values: dataset choice and image
# size, batching/validation fraction, then the random flip/rotation switches.
mirnet.build_datasets(
    dataset_label=dataset_label,
    image_size=image_size,
    batch_size=batch_size,
    val_split=val_split,
    apply_random_rotation=apply_random_rotation,
    apply_random_vertical_flip=apply_random_vertical_flip,
    apply_random_horizontal_flip=apply_random_horizontal_flip,
)

In [None]:
# Build the MIRNet model: architecture depth first, then the learning-rate /
# epsilon settings, then the mixed-precision switch chosen in the form.
mirnet.build_model(
    num_recursive_residual_groups=num_recursive_residual_groups,
    num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,
    learning_rate=learning_rate,
    epsilon=epsilon,
    use_mixed_precision=use_mixed_precision,
)

In [None]:
# Train MIRNet for the configured number of epochs. The Zero-DCE training cell
# also assigns `history`, so keep a model-specific alias that is not clobbered;
# `history` is still bound for any code that expects that name.
history = mirnet_history = mirnet.train(epochs=epochs)

In [None]:
# Reload the weights file from the experiment directory before evaluating.
# NOTE(review): nothing in this notebook saves MIRNet weights explicitly; this
# assumes mirnet.train() wrote weights.h5 into mirnet.experiment_name — confirm.
mirnet.load_weights(os.path.join(mirnet.experiment_name, "weights.h5"))

In [None]:
# Show each low-light test image next to its ground truth and MIRNet's output.
# The two path lists are walked together with zip instead of indexing one list
# by the other's position. The `with` blocks close each image file after its
# figure has been drawn, instead of leaving one open handle per test image.
for low_image_file, ground_truth_file in zip(
    mirnet.test_low_images, mirnet.test_enhanced_images
):
    with Image.open(low_image_file) as original_image:
        with Image.open(ground_truth_file) as ground_truth:
            enhanced_image = mirnet.infer(original_image)
            commons.plot_results(
                [original_image, ground_truth, enhanced_image],
                ["Original Image", "Ground Truth", "Enhanced Image"],
                (18, 18),
            )

In [None]:
# @title Zero-DCE Train Configs
# Colab form: each `# @param` annotation exposes the assignment on its line as
# an editable form widget. These assignments rebind the same global names the
# MIRNet section used (experiment_name, image_size, batch_size, epochs, ...).
# Re-running a MIRNet cell after this one would therefore pick up these values.

experiment_name = "unpaired_low_light_256_resize" # @param {type:"string"}
image_size = 256 # @param {type:"integer"}
dataset_label = "unpaired" # @param ["lol", "unpaired"]
use_mixed_precision = False # @param {type:"boolean"}
apply_resize = True # @param {type:"boolean"}
apply_random_horizontal_flip = True # @param {type:"boolean"}
apply_random_vertical_flip = True # @param {type:"boolean"}
apply_random_rotation = True # @param {type:"boolean"}
# An empty value is turned into None by the cell that constructs ZeroDCE.
# Avoid saving or sharing the notebook with a real key typed in here: the
# value becomes part of the cell source.
wandb_api_key = "" # @param {type:"string"}
# NOTE(review): the slider allows 1.0, which would presumably leave no training data — confirm.
val_split = 0.1 # @param {type:"slider", min:0.1, max:1.0, step:0.1}
batch_size = 16 # @param {type:"integer"}
learning_rate = 1e-4 # @param {type:"number"}
# NOTE(review): epsilon is set here but no Zero-DCE cell below passes it anywhere.
epsilon = 1e-3 # @param {type:"number"}
epochs = 100 # @param {type:"slider", min:10, max:100, step:5}

In [None]:
# Create the Zero-DCE wrapper. Unlike MIRNet, mixed precision is chosen at
# construction time here. The form field is always a string, so an empty key
# collapses to None and a non-empty key is passed through.
zero_dce = ZeroDCE(
    use_mixed_precision=use_mixed_precision,
    experiment_name=experiment_name,
    wandb_api_key=wandb_api_key or None,
)

In [None]:
# Build the Zero-DCE datasets from the form values: dataset choice, sizing and
# resize switch, batching/validation fraction, then the random flip/rotation switches.
zero_dce.build_datasets(
    dataset_label=dataset_label,
    image_size=image_size,
    apply_resize=apply_resize,
    batch_size=batch_size,
    val_split=val_split,
    apply_random_rotation=apply_random_rotation,
    apply_random_vertical_flip=apply_random_vertical_flip,
    apply_random_horizontal_flip=apply_random_horizontal_flip,
)

In [None]:
# Compile and train Zero-DCE, then write its weights into the experiment directory.
zero_dce.compile(learning_rate=learning_rate)
# `history` is also assigned by the MIRNet training cell; keep a model-specific alias.
history = zero_dce_history = zero_dce.train(epochs=epochs)
# Writing an .h5 file into a directory that does not exist fails. Create the
# directory unless training already did. An empty name means "current
# directory", which needs no creation (and os.makedirs("") would raise).
if experiment_name:
    os.makedirs(experiment_name, exist_ok=True)
zero_dce.save_weights(os.path.join(experiment_name, "weights.h5"))

In [None]:
# Show each low-light test image next to Zero-DCE's output (no ground-truth
# panel is plotted here). The loop index was unused, so plain iteration suffices.
# The `with` block closes each image file after its figure has been drawn.
for low_image_file in zero_dce.test_low_images:
    with Image.open(low_image_file) as original_image:
        enhanced_image = zero_dce.infer(original_image)
        commons.plot_results(
            [original_image, enhanced_image],
            ["Original Image", "Enhanced Image"],
            (18, 18),
        )