"""Streamlit front-end for low-light image enhancement.

Loads ``model.enhance_net_nopool`` with pretrained weights, lets the user
upload an image, and shows the original next to the enhanced result.
"""
import numpy as np
import streamlit as st
import torch
import torch.optim
from PIL import Image
from torchvision import transforms

import model

# Input height/width are cropped down to a multiple of this value before
# inference; it is also passed to the network constructor.
scale_factor = 1

# The model is a shared, unhashable object, which is what st.cache_resource
# (Streamlit >= 1.18) is meant for. Older releases only have the legacy
# st.cache; allow_output_mutation=True stops it from re-hashing the returned
# module on every rerun.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model() -> torch.nn.Module:
    """Build the enhancement network and load its pretrained weights.

    Weights are read from ``lowlight-dce-snapshot.pth`` in the working
    directory and mapped onto the CPU. The network is switched to eval mode
    because it is only used for inference.

    Returns:
        The ready-to-use network (cached across Streamlit reruns).
    """
    DCE_net = model.enhance_net_nopool(scale_factor)
    DCE_net.load_state_dict(
        torch.load("lowlight-dce-snapshot.pth", map_location=torch.device("cpu"))
    )
    DCE_net.eval()
    return DCE_net


def fix_lowlight(image: Image.Image) -> Image.Image:
    """Return an enhanced copy of a low-light PIL image.

    The image is converted to RGB, scaled to [0, 1], cropped so both sides
    are multiples of ``scale_factor``, and passed to the cached network as a
    single-image batch.

    Args:
        image: Input image in any PIL mode convertible to RGB.

    Returns:
        The enhanced image in RGB mode, with the (possibly cropped) size used
        for inference.
    """
    DCE_net = load_model()

    # Accept any mode (L, RGBA, P, ...): the network is fed 3 channels.
    rgb = np.asarray(image.convert("RGB")) / 255.0
    data_lowlight = torch.from_numpy(rgb).float()

    h = (data_lowlight.shape[0] // scale_factor) * scale_factor
    w = (data_lowlight.shape[1] // scale_factor) * scale_factor
    data_lowlight = data_lowlight[0:h, 0:w, :]

    # (H, W, C) -> (C, H, W), then add a leading batch dimension of 1.
    data_lowlight = data_lowlight.permute(2, 0, 1).unsqueeze(0)

    # Inference only: no_grad avoids building and holding an autograd graph.
    with torch.no_grad():
        enhanced_image, _ = DCE_net(data_lowlight)

    return transforms.ToPILImage()(enhanced_image[0]).convert("RGB")


def main() -> None:
    """Render the page: an upload widget plus a side-by-side comparison."""
    st.title("Lowlight Enhancement")
    st.write("This is a simple lowlight enhancement app with great performance and does not require paired images to train.")
    st.write("The model runs at 1000/11 FPS on single GPU/CPU on images with a size of 1200*900*3")

    uploaded_file = st.file_uploader("Lowlight Image")
    if uploaded_file is None:
        return

    data_lowlight = Image.open(uploaded_file).convert('RGB')

    col1, col2 = st.columns(2)
    col1.write("Original (Lowlight)")
    col1.image(data_lowlight, caption="Lowlight Image", use_column_width=True)

    col2.write("Enhanced")
    with st.spinner('🧠 Enhancing...'):
        fixed_img = fix_lowlight(data_lowlight)
    col2.image(fixed_img, caption="Enhanced Image", use_column_width=True)


# `streamlit run` executes this script with __name__ == "__main__".
if __name__ == "__main__":
    main()