IAT_enhancement / app.py
atlury's picture
Update app.py
2b3e4a6 verified
import os
from functools import lru_cache

import torch
import cv2
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, ConvertImageDtype
from PIL import Image
import numpy as np
import gradio as gr

from model import IAT # Ensure the correct import path
def set_example_image(example: list) -> dict:
    """Return a Gradio update that shows a clicked example in the input image.

    ``example`` is one sample row of a ``gr.Dataset``; since each dataset is
    built with a single component, the row's first entry is that component's
    value (an example image path).
    """
    image_path = example[0]
    return gr.Image.update(value=image_path)
def tensor_to_numpy(tensor):
    """Convert an image tensor to a ``uint8`` NumPy array for display.

    A 3-dimensional tensor whose first axis has size 3 is treated as CHW and
    transposed to HWC; any other layout is returned without reordering.
    Values are multiplied by 255 (so the input is assumed to lie roughly in
    [0, 1]), rounded to the nearest integer and clipped to [0, 255].

    Args:
        tensor: A ``torch.Tensor``; it is detached and moved to CPU first.

    Returns:
        ``np.ndarray`` of dtype ``uint8``.
    """
    print("Converting tensor to numpy array...")
    array = tensor.detach().cpu().numpy()
    if array.ndim == 3 and array.shape[0] == 3:  # Convert CHW to HWC
        array = array.transpose(1, 2, 0)
    # Round before the cast: astype(np.uint8) discards the fractional part,
    # which would bias every pixel downwards (e.g. 254.7 -> 254).
    return np.clip(np.rint(array * 255), 0, 255).astype(np.uint8)
# Checkpoint files for the two tasks offered in the UI (paths are relative to
# the working directory the app is launched from).
_DARK_CHECKPOINT = './checkpoint/best_Epoch_lol.pth'
_EXPOSURE_CHECKPOINT = './checkpoint/best_Epoch_exposure.pth'

# Preprocessing shared by both tasks. The transforms hold no per-image state,
# so the pipeline is built once instead of on every request.
_TRANSFORM = Compose([
    ToTensor(),
    Resize(384),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ConvertImageDtype(torch.float)
])


@lru_cache(maxsize=None)  # keys are the fixed checkpoint paths above
def _load_model(checkpoint_file_path: str) -> IAT:
    """Return an eval-mode IAT model with weights from *checkpoint_file_path*.

    Cached per path so each checkpoint is read from disk once per process
    instead of on every button click.
    """
    model = IAT()
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    print(f'Load model from {checkpoint_file_path}')
    return model


def _enhance(img, checkpoint_file_path: str) -> np.ndarray:
    """Run the IAT model stored at *checkpoint_file_path* on *img*.

    Args:
        img: Array from the ``gr.Image(type='numpy')`` input, or ``None``
            when no image has been uploaded.
        checkpoint_file_path: Weights file to load (cached after first use).

    Returns:
        The enhanced image converted by ``tensor_to_numpy``.

    Raises:
        gr.Error: If *img* is ``None``, so the UI shows a readable message
            instead of failing inside the transform pipeline.
    """
    if img is None:
        raise gr.Error('Please upload an image first.')
    model = _load_model(checkpoint_file_path)
    input_img = _TRANSFORM(img)
    print(f'Image shape after transform: {input_img.shape}')
    with torch.no_grad():
        enhanced_img = model(input_img.unsqueeze(0))
    # NOTE(review): assumes IAT.forward returns a single (B, C, H, W) tensor so
    # [0] selects the first batch item; if it returns a tuple, [0] picks the
    # tuple's first element instead -- confirm against model.IAT.
    return tensor_to_numpy(enhanced_img[0])


def dark_inference(img):
    """Enhance a low-light image with the LOL checkpoint.

    Args:
        img: Array from the input ``gr.Image`` (``None`` if empty).

    Returns:
        The enhanced image as a ``uint8`` NumPy array.
    """
    print("Starting dark inference...")
    result_img = _enhance(img, _DARK_CHECKPOINT)
    print("Dark inference completed.")
    return result_img


def exposure_inference(img):
    """Correct the exposure of an image with the exposure checkpoint.

    Args:
        img: Array from the input ``gr.Image`` (``None`` if empty).

    Returns:
        The corrected image as a ``uint8`` NumPy array.
    """
    print("Starting exposure inference...")
    result_img = _enhance(img, _EXPOSURE_CHECKPOINT)
    print("Exposure inference completed.")
    return result_img
# Build the Gradio UI: an input image, two task buttons, a result image, and
# one example gallery per task.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
# IAT
Gradio demo for <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>IAT</a>: To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.
"""
    )
    with gr.Box():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    dark_button = gr.Button('Low-light Enhancement')
                with gr.Row():
                    exposure_button = gr.Button('Exposure Correction')
            with gr.Column():
                res_image = gr.Image(type='numpy', label='Results')
        with gr.Row():
            dark_example_images = gr.Dataset(
                components=[input_image],
                samples=[['dark_imgs/1.jpg'], ['dark_imgs/2.jpg'], ['dark_imgs/3.jpg']]
            )
        with gr.Row():
            exposure_example_images = gr.Dataset(
                components=[input_image],
                samples=[['exposure_imgs/1.jpg'], ['exposure_imgs/2.jpg'], ['exposure_imgs/3.jpeg']]
            )
    gr.Markdown(
        """
<p style='text-align: center'><a href='https://arxiv.org/abs/2205.14871' target='_blank'>You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction</a> | <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>Github Repo</a></p>
"""
    )
    # Each button runs its task on the current input and shows the result.
    dark_button.click(fn=dark_inference, inputs=input_image, outputs=res_image)
    exposure_button.click(fn=exposure_inference, inputs=input_image, outputs=res_image)
    # Clicking an example row loads that image into the input component.
    dark_example_images.click(fn=set_example_image, inputs=dark_example_images, outputs=dark_example_images.components)
    exposure_example_images.click(fn=set_example_image, inputs=exposure_example_images, outputs=exposure_example_images.components)

# Enable request queuing through Blocks.queue(); the `enable_queue` argument
# of launch() is deprecated in Gradio 3.x and was removed in Gradio 4.
demo.queue()
demo.launch()