{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/soumik12345/enhance-me\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Enhance-Me: training MIRNet and Zero-DCE\n",
    "\n",
    "This notebook trains two image-enhancement models from the [`enhance-me`](https://github.com/soumik12345/enhance-me) repository and plots their predictions on the test images:\n",
    "\n",
    "1. **MIRNet** - configured, trained and evaluated in the first half.\n",
    "2. **Zero-DCE** - configured, trained and evaluated in the second half.\n",
    "\n",
    "**Caveat:** both halves reuse the same configuration variable names (`experiment_name`, `image_size`, `batch_size`, ...). Always re-run a section's *Train Configs* cell before re-running any other cell of that section, otherwise the values of the other model are picked up."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Clone the repository into `./enhance-me` and install the extra Python packages into the kernel's environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1JryaVhtBHij",
    "outputId": "97ee6a4a-2479-4124-e96a-f0a792bdec46"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/soumik12345/enhance-me\n",
    "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
    "%pip install -qqq wandb streamlit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G_c4VtXWHR5l"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the `enhance_me` package importable in both supported layouts:\n",
    "#   \"..\"          - the notebook is run from a sub-directory of a local checkout\n",
    "#   \"enhance-me\"  - the notebook is run next to the clone created above (e.g. Colab)\n",
    "# The membership test keeps the cell idempotent when it is re-run.\n",
    "for repo_root in (\"..\", \"enhance-me\"):\n",
    "    if repo_root not in sys.path:\n",
    "        sys.path.append(repo_root)\n",
    "\n",
    "from PIL import Image\n",
    "from enhance_me import commons\n",
    "from enhance_me.mirnet import MIRNet\n",
    "from enhance_me.zero_dce import ZeroDCE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MIRNet\n",
    "\n",
    "### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZpBHbYaMIqP_"
   },
   "outputs": [],
   "source": [
    "# @title MIRNet Train Configs\n",
    "\n",
    "# NOTE(review): the default experiment name says \"256\" while image_size\n",
    "# defaults to 128 - confirm which one is intended.\n",
    "experiment_name = \"lol_dataset_256\"  # @param {type:\"string\"}\n",
    "image_size = 128  # @param {type:\"integer\"}\n",
    "dataset_label = \"lol\"  # @param [\"lol\"]\n",
    "apply_random_horizontal_flip = True  # @param {type:\"boolean\"}\n",
    "apply_random_vertical_flip = True  # @param {type:\"boolean\"}\n",
    "apply_random_rotation = True  # @param {type:\"boolean\"}\n",
    "use_mixed_precision = True  # @param {type:\"boolean\"}\n",
    "wandb_api_key = \"\"  # @param {type:\"string\"}\n",
    "val_split = 0.1  # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
    "batch_size = 4  # @param {type:\"integer\"}\n",
    "num_recursive_residual_groups = 3  # @param {type:\"slider\", min:1, max:5, step:1}\n",
    "num_multi_scale_residual_blocks = 2  # @param {type:\"slider\", min:1, max:5, step:1}\n",
    "learning_rate = 1e-4  # @param {type:\"number\"}\n",
    "epsilon = 1e-3  # @param {type:\"number\"}\n",
    "epochs = 50  # @param {type:\"slider\", min:10, max:100, step:5}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the datasets and the model, then train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "id": "IVRoedqBIMuH",
    "outputId": "53ca5beb-871a-4ec3-b757-173e09a15331"
   },
   "outputs": [],
   "source": [
    "# An empty key in the form is forwarded as `None`.\n",
    "mirnet = MIRNet(\n",
    "    experiment_name=experiment_name,\n",
    "    wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "O66Iwzx8IsGh",
    "outputId": "0b6f1683-65d1-4737-a32f-d36b331d2bc2"
   },
   "outputs": [],
   "source": [
    "mirnet.build_datasets(\n",
    "    image_size=image_size,\n",
    "    dataset_label=dataset_label,\n",
    "    apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
    "    apply_random_vertical_flip=apply_random_vertical_flip,\n",
    "    apply_random_rotation=apply_random_rotation,\n",
    "    val_split=val_split,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tsfKrBCsL_Bb"
   },
   "outputs": [],
   "source": [
    "mirnet.build_model(\n",
    "    use_mixed_precision=use_mixed_precision,\n",
    "    num_recursive_residual_groups=num_recursive_residual_groups,\n",
    "    num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
    "    learning_rate=learning_rate,\n",
    "    epsilon=epsilon,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "y3L9wlpkNziL",
    "outputId": "5149f0e7-91f4-450f-c43a-1b6028692bbc"
   },
   "outputs": [],
   "source": [
    "history = mirnet.train(epochs=epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate on the test images\n",
    "\n",
    "Load `weights.h5` from the experiment directory, then plot each test input next to its ground truth and the model's prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mirnet.load_weights(os.path.join(mirnet.experiment_name, \"weights.h5\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "daFKbgBkiyzc"
   },
   "outputs": [],
   "source": [
    "# The two file lists are paired by position: the i-th input image belongs to\n",
    "# the i-th ground-truth image.\n",
    "for low_image_file, ground_truth_file in zip(\n",
    "    mirnet.test_low_images, mirnet.test_enhanced_images\n",
    "):\n",
    "    original_image = Image.open(low_image_file)\n",
    "    enhanced_image = mirnet.infer(original_image)\n",
    "    ground_truth = Image.open(ground_truth_file)\n",
    "    commons.plot_results(\n",
    "        [original_image, ground_truth, enhanced_image],\n",
    "        [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
    "        (18, 18),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-DCE\n",
    "\n",
    "### Configuration\n",
    "\n",
    "The cell below **overwrites** the configuration variables used by the MIRNet section above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dO-IbNQHkB3R"
   },
   "outputs": [],
   "source": [
    "# @title Zero-DCE Train Configs\n",
    "\n",
    "experiment_name = \"unpaired_low_light_256_resize\"  # @param {type:\"string\"}\n",
    "image_size = 256  # @param {type:\"integer\"}\n",
    "dataset_label = \"unpaired\"  # @param [\"lol\", \"unpaired\"]\n",
    "use_mixed_precision = False  # @param {type:\"boolean\"}\n",
    "apply_resize = True  # @param {type:\"boolean\"}\n",
    "apply_random_horizontal_flip = True  # @param {type:\"boolean\"}\n",
    "apply_random_vertical_flip = True  # @param {type:\"boolean\"}\n",
    "apply_random_rotation = True  # @param {type:\"boolean\"}\n",
    "wandb_api_key = \"\"  # @param {type:\"string\"}\n",
    "val_split = 0.1  # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
    "batch_size = 16  # @param {type:\"integer\"}\n",
    "learning_rate = 1e-4  # @param {type:\"number\"}\n",
    "epsilon = 1e-3  # @param {type:\"number\"}\n",
    "epochs = 100  # @param {type:\"slider\", min:10, max:100, step:5}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the model and the datasets, then train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An empty key in the form is forwarded as `None`.\n",
    "zero_dce = ZeroDCE(\n",
    "    experiment_name=experiment_name,\n",
    "    wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
    "    use_mixed_precision=use_mixed_precision,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_dce.build_datasets(\n",
    "    image_size=image_size,\n",
    "    dataset_label=dataset_label,\n",
    "    apply_resize=apply_resize,\n",
    "    apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
    "    apply_random_vertical_flip=apply_random_vertical_flip,\n",
    "    apply_random_rotation=apply_random_rotation,\n",
    "    val_split=val_split,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "zero_dce.compile(learning_rate=learning_rate)\n",
    "history = zero_dce.train(epochs=epochs)\n",
    "\n",
    "# Make sure the experiment directory exists before the weights file is\n",
    "# written into it; `exist_ok=True` keeps the cell safe to re-run.\n",
    "os.makedirs(experiment_name, exist_ok=True)\n",
    "zero_dce.save_weights(os.path.join(experiment_name, \"weights.h5\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate on the test images\n",
    "\n",
    "Plot each test input next to the model's prediction (no ground truth is plotted for this model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for low_image_file in zero_dce.test_low_images:\n",
    "    original_image = Image.open(low_image_file)\n",
    "    enhanced_image = zero_dce.infer(original_image)\n",
    "    commons.plot_results(\n",
    "        [original_image, enhanced_image],\n",
    "        [\"Original Image\", \"Enhanced Image\"],\n",
    "        (18, 18),\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyN4LuJh6kWhbqxzA5s9sp7k",
   "collapsed_sections": [],
   "include_colab_link": true,
   "machine_shape": "hm",
   "name": "enhance-me-train.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}