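# Gradio demo for IAT (Illumination-Adaptive-Transformer):
# low-light enhancement and exposure correction with pretrained checkpoints.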
import os
import torch
import cv2
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, ConvertImageDtype
from PIL import Image
import numpy as np
import gradio as gr
from model import IAT  # Ensure the correct import path
def set_example_image(example: list) -> dict:
    # Populate the input image component with the clicked example.
    return gr.Image.update(value=example[0])
def tensor_to_numpy(tensor):
    print("Converting tensor to numpy array...")
    tensor = tensor.detach().cpu().numpy()
    if tensor.ndim == 3 and tensor.shape[0] == 3:  # Convert CHW to HWC
        tensor = tensor.transpose(1, 2, 0)
    tensor = np.clip(tensor * 255, 0, 255).astype(np.uint8)  # Ensure the output is uint8
    return tensor
def dark_inference(img):
    print("Starting dark inference...")
    model = IAT()
    # Checkpoint trained for low-light (LOL) enhancement.
    checkpoint_file_path = './checkpoint/best_Epoch_lol.pth'
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    print(f'Loaded model from {checkpoint_file_path}')

    transform = Compose([
        ToTensor(),
        Resize(384),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ConvertImageDtype(torch.float)
    ])
    input_img = transform(img)
    print(f'Image shape after transform: {input_img.shape}')

    with torch.no_grad():
        enhanced_img = model(input_img.unsqueeze(0))
    result_img = tensor_to_numpy(enhanced_img[0])
    print("Dark inference completed.")
    return result_img
def exposure_inference(img):
    print("Starting exposure inference...")
    model = IAT()
    # Checkpoint trained for exposure correction.
    checkpoint_file_path = './checkpoint/best_Epoch_exposure.pth'
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    print(f'Loaded model from {checkpoint_file_path}')

    transform = Compose([
        ToTensor(),
        Resize(384),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ConvertImageDtype(torch.float)
    ])
    input_img = transform(img)
    print(f'Image shape after transform: {input_img.shape}')

    with torch.no_grad():
        enhanced_img = model(input_img.unsqueeze(0))
    result_img = tensor_to_numpy(enhanced_img[0])
    print("Exposure inference completed.")
    return result_img
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        # IAT
        Gradio demo for <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>IAT</a>: to use it, simply upload your image or click one of the examples to load it. Read more at the links below.
        """
    )
    with gr.Box():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    dark_button = gr.Button('Low-light Enhancement')
                with gr.Row():
                    exposure_button = gr.Button('Exposure Correction')
            with gr.Column():
                res_image = gr.Image(type='numpy', label='Results')
        with gr.Row():
            dark_example_images = gr.Dataset(
                components=[input_image],
                samples=[['dark_imgs/1.jpg'], ['dark_imgs/2.jpg'], ['dark_imgs/3.jpg']]
            )
        with gr.Row():
            exposure_example_images = gr.Dataset(
                components=[input_image],
                samples=[['exposure_imgs/1.jpg'], ['exposure_imgs/2.jpg'], ['exposure_imgs/3.jpeg']]
            )
    gr.Markdown(
        """
        <p style='text-align: center'><a href='https://arxiv.org/abs/2205.14871' target='_blank'>You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction</a> | <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>Github Repo</a></p>
        """
    )
    dark_button.click(fn=dark_inference, inputs=input_image, outputs=res_image)
    exposure_button.click(fn=exposure_inference, inputs=input_image, outputs=res_image)
    dark_example_images.click(fn=set_example_image, inputs=dark_example_images, outputs=dark_example_images.components)
    exposure_example_images.click(fn=set_example_image, inputs=exposure_example_images, outputs=exposure_example_images.components)

demo.launch(enable_queue=True)