# Author: geekyrakshit — "added model loading function" (commit 246dd98)
import os
import numpy as np
from PIL import Image
from typing import List
from datetime import datetime
from tensorflow import keras
from tensorflow.keras import optimizers, models
from wandb.keras import WandbCallback
from .dataloader import LowLightDataset
from .models import build_mirnet_model
from .losses import CharbonnierLoss
from ..commons import (
peak_signal_noise_ratio,
closest_number,
init_wandb,
download_lol_dataset,
)
class MIRNet:
    """High-level wrapper for training and running the MIRNet low-light
    image-enhancement model.

    Handles dataset construction, model building/compilation, checkpointing,
    training (with optional Weights & Biases logging), and single-image
    inference.

    Args:
        experiment_name: Name used for the log/checkpoint directory and the
            wandb run. Required before calling :meth:`train`.
        wandb_api_key: Optional W&B API key; when provided a wandb run is
            initialized and a ``WandbCallback`` is added during training.
    """

    def __init__(self, experiment_name=None, wandb_api_key=None) -> None:
        self.experiment_name = experiment_name
        if wandb_api_key is not None:
            init_wandb("mirnet", experiment_name, wandb_api_key)
            self.using_wandb = True
        else:
            self.using_wandb = False

    def build_datasets(
        self,
        image_size: int = 256,
        dataset_label: str = "lol",
        apply_random_horizontal_flip: bool = True,
        apply_random_vertical_flip: bool = True,
        apply_random_rotation: bool = True,
        val_split: float = 0.2,
        batch_size: int = 16,
    ):
        """Download the dataset and build the train/validation pipelines.

        Args:
            image_size: Side length of the random crops fed to the model.
            dataset_label: Dataset identifier; only ``"lol"`` is supported.
            apply_random_horizontal_flip: Enable horizontal-flip augmentation.
            apply_random_vertical_flip: Enable vertical-flip augmentation.
            apply_random_rotation: Enable rotation augmentation.
            val_split: Fraction of training pairs held out for validation.
            batch_size: Batch size for both pipelines.

        Raises:
            ValueError: If ``dataset_label`` is not a supported dataset.
        """
        if dataset_label != "lol":
            # Previously an unknown label silently skipped the download and
            # left self.low_images unset, surfacing later as a confusing
            # AttributeError; fail fast with a clear message instead.
            raise ValueError(
                f"Unknown dataset_label {dataset_label!r}; only 'lol' is supported."
            )
        (self.low_images, self.enhanced_images), (
            self.test_low_images,
            self.test_enhanced_images,
        ) = download_lol_dataset()
        self.data_loader = LowLightDataset(
            image_size=image_size,
            apply_random_horizontal_flip=apply_random_horizontal_flip,
            apply_random_vertical_flip=apply_random_vertical_flip,
            apply_random_rotation=apply_random_rotation,
        )
        (self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
            low_light_images=self.low_images,
            enhanced_images=self.enhanced_images,
            val_split=val_split,
            batch_size=batch_size,
        )

    def build_model(
        self,
        num_recursive_residual_groups: int = 3,
        num_multi_scale_residual_blocks: int = 2,
        channels: int = 64,
        learning_rate: float = 1e-4,
        epsilon: float = 1e-3,
    ):
        """Build and compile the MIRNet model.

        Args:
            num_recursive_residual_groups: Number of RRGs in the network.
            num_multi_scale_residual_blocks: Number of MRBs per group.
            channels: Base channel width of the network.
            learning_rate: Adam learning rate.
            epsilon: Smoothing constant for the Charbonnier loss.
        """
        self.model = build_mirnet_model(
            num_rrg=num_recursive_residual_groups,
            num_mrb=num_multi_scale_residual_blocks,
            channels=channels,
        )
        self.model.compile(
            optimizer=optimizers.Adam(learning_rate=learning_rate),
            loss=CharbonnierLoss(epsilon=epsilon),
            metrics=[peak_signal_noise_ratio],
        )

    def load_model(
        self, filepath, custom_objects=None, compile=True, options=None
    ) -> None:
        """Load a full saved model from ``filepath``.

        Mirrors the ``tf.keras.models.load_model`` signature (hence the
        builtin-shadowing ``compile`` keyword, kept for API compatibility).
        """
        self.model = models.load_model(
            filepath=filepath,
            custom_objects=custom_objects,
            compile=compile,
            options=options,
        )

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Save model weights; mirrors ``Model.save_weights``."""
        self.model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Load model weights; mirrors ``Model.load_weights``."""
        self.model.load_weights(
            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
        )

    def train(self, epochs: int):
        """Train the model with TensorBoard, checkpoint, and LR-plateau
        callbacks (plus wandb logging when enabled).

        Args:
            epochs: Number of training epochs.

        Returns:
            The Keras ``History`` object from ``Model.fit``.
        """
        # os.path.join(None, ...) raises TypeError; fall back to a default
        # directory when no experiment name was given at construction time.
        experiment_dir = (
            self.experiment_name if self.experiment_name is not None else "mirnet"
        )
        log_dir = os.path.join(
            experiment_dir,
            "logs",
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        )
        tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
        model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
            os.path.join(experiment_dir, "weights.h5"),
            save_best_only=True,
            save_weights_only=True,
        )
        reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",  # PSNR is a higher-is-better metric
        )
        callbacks = [
            tensorboard_callback,
            model_checkpoint_callback,
            reduce_lr_callback,
        ]
        if self.using_wandb:
            callbacks += [WandbCallback()]
        history = self.model.fit(
            self.train_dataset,
            validation_data=self.val_dataset,
            epochs=epochs,
            callbacks=callbacks,
        )
        return history

    def infer(
        self,
        original_image,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        """Enhance a single low-light image.

        The input is downscaled by ``image_resize_factor`` and its dimensions
        snapped to the nearest values divisible by 4 (presumably required by
        the model's down/upsampling path — confirm against the architecture).

        Args:
            original_image: RGB ``PIL.Image`` to enhance (3 channels assumed).
            image_resize_factor: Shrink factor applied before inference
                (useful to bound memory on large inputs).
            resize_output: If True, resize the enhanced result back to the
                original input dimensions.

        Returns:
            The enhanced image as a ``PIL.Image``.
        """
        width, height = original_image.size
        # int() cast: floor-dividing by a float factor yields a float, which
        # modern Pillow rejects as a resize dimension.
        target_width = closest_number(int(width // image_resize_factor), 4)
        target_height = closest_number(int(height // image_resize_factor), 4)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        original_image = original_image.resize(
            (target_width, target_height), Image.LANCZOS
        )
        image = keras.preprocessing.image.img_to_array(original_image)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=0)
        output = self.model.predict(image)
        # Rescale the model's [0, 1] output back to 8-bit pixel values.
        output_image = output[0] * 255.0
        output_image = output_image.clip(0, 255)
        output_image = output_image.reshape(
            (np.shape(output_image)[0], np.shape(output_image)[1], 3)
        )
        output_image = Image.fromarray(np.uint8(output_image))
        if resize_output:
            output_image = output_image.resize((width, height), Image.LANCZOS)
        return output_image

    def infer_from_file(
        self,
        original_image_file: str,
        image_resize_factor: float = 1.0,
        resize_output: bool = False,
    ):
        """Open ``original_image_file`` and enhance it via :meth:`infer`."""
        original_image = Image.open(original_image_file)
        return self.infer(original_image, image_resize_factor, resize_output)