# Build and train a MIRNet model on the LoL dataset for low-light image enhancement.
import random

import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from google.colab import drive

drive.mount('/content/gdrive')

random.seed(10)  # seeds Python's RNG, used when sampling test images below

IMAGE_SIZE = 128        # side length of the random training crops
BATCH_SIZE = 4
MAX_TRAIN_IMAGES = 300  # first 300 pairs train, the remainder validate


def read_image(image_path):
    """Load a PNG and scale it to float32 in [0, 1]."""
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image.set_shape([None, None, 3])
    image = tf.cast(image, dtype=tf.float32) / 255.0
    return image


def random_crop(low_image, enhanced_image):
    """Take the same random IMAGE_SIZE x IMAGE_SIZE crop from both images."""
    low_image_shape = tf.shape(low_image)[:2]
    low_w = tf.random.uniform(
        shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
    )
    low_h = tf.random.uniform(
        shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
    )
    # Use identical offsets so the low-light / ground-truth pair stays aligned.
    enhanced_w = low_w
    enhanced_h = low_h
    low_image_cropped = low_image[
        low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
    ]
    enhanced_image_cropped = enhanced_image[
        enhanced_h : enhanced_h + IMAGE_SIZE, enhanced_w : enhanced_w + IMAGE_SIZE
    ]
    return low_image_cropped, enhanced_image_cropped


def load_data(low_light_image_path, enhanced_image_path):
    low_light_image = read_image(low_light_image_path)
    enhanced_image = read_image(enhanced_image_path)
    low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
    return low_light_image, enhanced_image


def get_dataset(low_light_images, enhanced_images):
    dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    return dataset


# LoL dataset layout on Drive: our485 for train/val pairs, eval15 for testing.
train_low_light_images = sorted(
    glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/low/*")
)[:MAX_TRAIN_IMAGES]
train_enhanced_images = sorted(
    glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/high/*")
)[:MAX_TRAIN_IMAGES]

val_low_light_images = sorted(
    glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/low/*")
)[MAX_TRAIN_IMAGES:]
val_enhanced_images = sorted(
    glob("/content/gdrive/MyDrive/dataset/lol_dataset/our485/high/*")
)[MAX_TRAIN_IMAGES:]

test_low_light_images = sorted(
    glob("/content/gdrive/MyDrive/dataset/lol_dataset/eval15/low/*")
)
test_enhanced_images = sorted(
    glob("/content/gdrive/MyDrive/dataset/lol_dataset/eval15/high/*")
)

train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
val_dataset = get_dataset(val_low_light_images, val_enhanced_images)

print("Train Dataset:", train_dataset)
print("Val Dataset:", val_dataset)
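# Optional sanity check, not part of the original script: pull one batch and show a
# low-light / ground-truth pair side by side, to confirm that the paired random
# crops line up spatially before spending time on training.
for low_batch, high_batch in train_dataset.take(1):
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(low_batch[0].numpy())
    axes[0].set_title("Low-light crop")
    axes[1].imshow(high_batch[0].numpy())
    axes[1].set_title("Ground-truth crop")
    for ax in axes:
        ax.axis("off")
    plt.show()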
def selective_kernel_feature_fusion(
    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
):
    """Fuse three multi-scale feature maps via channel-wise soft attention (SKFF)."""
    channels = list(multi_scale_feature_1.shape)[-1]
    combined_feature = layers.Add()(
        [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
    )
    gap = layers.GlobalAveragePooling2D()(combined_feature)
    channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
    compact_feature_representation = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(channel_wise_statistics)
    # One attention descriptor per input stream.
    feature_descriptor_1 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_2 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_3 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_1 = multi_scale_feature_1 * feature_descriptor_1
    feature_2 = multi_scale_feature_2 * feature_descriptor_2
    feature_3 = multi_scale_feature_3 * feature_descriptor_3
    aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
    return aggregated_feature


def spatial_attention_block(input_tensor):
    """Weight each spatial location using mean- and max-pooled channel statistics."""
    average_pooling = tf.reduce_mean(input_tensor, axis=-1)
    average_pooling = tf.expand_dims(average_pooling, axis=-1)
    max_pooling = tf.reduce_max(input_tensor, axis=-1)
    max_pooling = tf.expand_dims(max_pooling, axis=-1)
    concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
    feature_map = tf.nn.sigmoid(feature_map)
    return input_tensor * feature_map


def channel_attention_block(input_tensor):
    """Squeeze-and-excitation style channel attention."""
    channels = list(input_tensor.shape)[-1]
    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
    feature_activations = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)
    return input_tensor * feature_activations


def dual_attention_unit_block(input_tensor):
    """Dual attention unit (DAU): channel and spatial attention with a residual add."""
    channels = list(input_tensor.shape)[-1]
    feature_map = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(input_tensor)
    feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
        feature_map
    )
    channel_attention = channel_attention_block(feature_map)
    spatial_attention = spatial_attention_block(feature_map)
    concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
    concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
    return layers.Add()([input_tensor, concatenation])


# Recursive Residual Modules
def down_sampling_module(input_tensor):
    """Halve the spatial resolution and double the channel count."""
    channels = list(input_tensor.shape)[-1]
    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.MaxPooling2D()(main_branch)
    main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
    skip_branch = layers.MaxPooling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
    return layers.Add()([skip_branch, main_branch])


def up_sampling_module(input_tensor):
    """Double the spatial resolution and halve the channel count."""
    channels = list(input_tensor.shape)[-1]
    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.UpSampling2D()(main_branch)
    main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
    skip_branch = layers.UpSampling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
    return layers.Add()([skip_branch, main_branch])
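# Optional shape check, not part of the original script: these modules also run
# eagerly on a dummy tensor, which makes it easy to confirm their shape contracts
# (DAU preserves shape; down-/up-sampling trade spatial resolution for channels).
_dummy = tf.random.uniform((1, 64, 64, 16))
print(dual_attention_unit_block(_dummy).shape)  # (1, 64, 64, 16)
print(down_sampling_module(_dummy).shape)       # (1, 32, 32, 32)
print(up_sampling_module(_dummy).shape)         # (1, 128, 128, 8)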
# MRB Block
def multi_scale_residual_block(input_tensor, channels):
    """Multi-scale residual block: three resolution streams exchanging information."""
    # Feature streams at full, half, and quarter resolution.
    level1 = input_tensor
    level2 = down_sampling_module(input_tensor)
    level3 = down_sampling_module(level2)
    # DAU
    level1_dau = dual_attention_unit_block(level1)
    level2_dau = dual_attention_unit_block(level2)
    level3_dau = dual_attention_unit_block(level3)
    # SKFF: fuse all three streams at each resolution.
    level1_skff = selective_kernel_feature_fusion(
        level1_dau,
        up_sampling_module(level2_dau),
        up_sampling_module(up_sampling_module(level3_dau)),
    )
    level2_skff = selective_kernel_feature_fusion(
        down_sampling_module(level1_dau),
        level2_dau,
        up_sampling_module(level3_dau),
    )
    level3_skff = selective_kernel_feature_fusion(
        down_sampling_module(down_sampling_module(level1_dau)),
        down_sampling_module(level2_dau),
        level3_dau,
    )
    # DAU 2
    level1_dau_2 = dual_attention_unit_block(level1_skff)
    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
    level3_dau_2 = up_sampling_module(
        up_sampling_module(dual_attention_unit_block(level3_skff))
    )
    # SKFF 2
    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
    return layers.Add()([input_tensor, conv])


def recursive_residual_group(input_tensor, num_mrb, channels):
    """Stack num_mrb MRBs between two convolutions, with a long residual connection."""
    conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
    for _ in range(num_mrb):
        conv1 = multi_scale_residual_block(conv1, channels)
    conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
    return layers.Add()([conv2, input_tensor])


def mirnet_model(num_rrg, num_mrb, channels):
    """Full MIRNet: the network predicts a residual that is added to the input."""
    input_tensor = keras.Input(shape=[None, None, 3])
    x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
    for _ in range(num_rrg):
        x1 = recursive_residual_group(x1, num_mrb, channels)
    conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
    output_tensor = layers.Add()([input_tensor, conv])
    return keras.Model(input_tensor, output_tensor)


model = mirnet_model(num_rrg=3, num_mrb=2, channels=64)


def charbonnier_loss(y_true, y_pred):
    """Charbonnier loss: a smooth, differentiable approximation of L1."""
    return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))


def peak_signal_noise_ratio(y_true, y_pred):
    # Images are scaled to [0, 1] in read_image, so max_val is 1.0 here.
    return tf.image.psnr(y_pred, y_true, max_val=1.0)


optimizer = keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer, loss=charbonnier_loss, metrics=[peak_signal_noise_ratio]
)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=1,  # number of training epochs; increase for a real training run
    callbacks=[
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",
        )
    ],
)

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

plt.plot(history.history["peak_signal_noise_ratio"], label="train_psnr")
plt.plot(history.history["val_peak_signal_noise_ratio"], label="val_psnr")
plt.xlabel("Epochs")
plt.ylabel("PSNR")
plt.title("Train and Validation PSNR Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()
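# Optional, not part of the original script: persist the trained weights to Drive so
# the model can be reloaded later without retraining. The filename is an assumption;
# point it at any writable location in the mounted Drive.
model.save_weights("/content/gdrive/MyDrive/mirnet_weights.h5")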
def plot_results(images, titles, figure_size=(12, 12)):
    """Show a row of images with per-panel titles."""
    fig = plt.figure(figsize=figure_size)
    for i in range(len(images)):
        fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
        _ = plt.imshow(images[i])
        plt.axis("off")
    plt.show()


def infer(original_image):
    """Run the model on a full-resolution image and return a PIL image."""
    image = keras.utils.img_to_array(original_image)
    image = image.astype("float32") / 255.0
    image = np.expand_dims(image, axis=0)
    output = model.predict(image)
    output_image = output[0] * 255.0
    output_image = output_image.clip(0, 255)
    return Image.fromarray(np.uint8(output_image))


for low_light_image in random.sample(test_low_light_images, 2):
    original_image = Image.open(low_light_image)
    enhanced_image = infer(original_image)
    plot_results(
        [original_image, ImageOps.autocontrast(original_image), enhanced_image],
        ["Original", "PIL Autocontrast", "MIRNet Enhanced"],
        (20, 12),
    )
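# Optional quantitative check, not part of the original script: PSNR of the model
# output against the paired ground truth for one eval15 example. Assumes the sorted
# low/high file lists line up, and that the image sides are divisible by 4 (the
# network pools twice), which holds for the LoL dataset.
low = np.asarray(Image.open(test_low_light_images[0]), dtype=np.float32) / 255.0
high = np.asarray(Image.open(test_enhanced_images[0]), dtype=np.float32) / 255.0
pred = model.predict(low[np.newaxis, ...])[0].clip(0.0, 1.0)
print("PSNR (input vs truth): ", tf.image.psnr(low, high, max_val=1.0).numpy())
print("PSNR (MIRNet vs truth):", tf.image.psnr(pred, high, max_val=1.0).numpy())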