# Commit 5ae48fc (author: lunde): "Convert to RGB to not fail on alpha-channels"
import torch
import torch.optim
import model
import numpy as np
from PIL import Image
import streamlit as st
from torchvision import transforms
# Passed to model.enhance_net_nopool when the network is built, and used by
# fix_lowlight to crop input height/width down to a multiple of this value
# (1 means no cropping). Its meaning inside the network is defined in `model`.
scale_factor: int = 1
@st.cache
def load_model() -> torch.nn.Module:
DCE_net = model.enhance_net_nopool(scale_factor)
DCE_net.load_state_dict(torch.load("lowlight-dce-snapshot.pth", map_location=torch.device('cpu')))
return DCE_net
def fix_lowlight(image: Image.Image) -> Image.Image:
DCE_net = load_model()
data_lowlight = np.asarray(image) / 255.0
data_lowlight = torch.from_numpy(data_lowlight).float()
h = (data_lowlight.shape[0] // scale_factor) * scale_factor
w = (data_lowlight.shape[1] // scale_factor) * scale_factor
data_lowlight = data_lowlight[0:h, 0:w, :]
data_lowlight = data_lowlight.permute(2, 0, 1)
data_lowlight = data_lowlight.unsqueeze(0)
enhanced_image, _ = DCE_net(data_lowlight)
im = transforms.ToPILImage()(enhanced_image[0]).convert("RGB")
return im
def main():
st.title("Lowlight Enhancement")
st.write("This is a simple lowlight enhancement app with great performance and does not require paired images to train.")
st.write("The model runs at 1000/11 FPS on single GPU/CPU on images with a size of 1200*900*3")
uploaded_file = st.file_uploader("Lowlight Image")
if uploaded_file:
data_lowlight = Image.open(uploaded_file).convert('RGB')
col1, col2 = st.columns(2)
col1.write("Original (Lowlight)")
col1.image(data_lowlight, caption="Lowlight Image", use_column_width=True)
col2.write("Enhanced")
with st.spinner('🧠 Enhancing...'):
fixed_img = fix_lowlight(data_lowlight)
col2.image(fixed_img, caption="Enhanced Image", use_column_width=True)
main()