# Streamlit front-end for low-light image enhancement with a fine-tuned MIRNet.
# The user uploads an image, the model is run on demand, and the enhanced
# result is shown next to the original with a download button.

import base64
import io
import os
import sys

import streamlit as st
import torch
import torchvision.transforms as T
import torchvision.utils as vutils
from huggingface_hub import hf_hub_download
from PIL import Image

from model.MIRNet.model import MIRNet


def run_model(input_image):
    """Enhance *input_image* with the fine-tuned MIRNet and return the result.

    The checkpoint is fetched through ``hf_hub_download`` and the weights are
    loaded into a fresh ``MIRNet`` on every call.

    Args:
        input_image: A PIL image (the caller passes an RGB image).

    Returns:
        A PIL image decoded from the PNG encoding of the model output.
    """
    # Prefer CUDA, then Apple MPS, and fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model = MIRNet(num_features=64).to(device)
    model_path = hf_hub_download(
        repo_id="dblasko/mirnet-low-light-img-enhancement",
        filename="mirnet_finetuned.pth",
    )
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    preprocess = T.Compose(
        [
            T.Resize(400),
            T.ToTensor(),
            # Mean 0 / std 1 leaves the values unchanged.
            T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        ]
    )

    with torch.no_grad():
        img_tensor = preprocess(input_image).unsqueeze(0).to(device)

        # Crop the two trailing (spatial) dimensions down to multiples of 8.
        # NOTE(review): presumably required by the network's internal
        # down/up-sampling -- confirm against the MIRNet implementation.
        height_excess = img_tensor.shape[2] % 8
        if height_excess != 0:
            img_tensor = img_tensor[:, :, :-height_excess, :]
        width_excess = img_tensor.shape[3] % 8
        if width_excess != 0:
            img_tensor = img_tensor[:, :, :, :-width_excess]

        output = model(img_tensor)

    # Encode to PNG in memory rather than through a shared "temp.png" on disk:
    # no leaked file handle, no clash between concurrent sessions, and nothing
    # left behind if an exception interrupts the round trip.
    png_buffer = io.BytesIO()
    vutils.save_image(output, png_buffer, format="png")
    png_buffer.seek(0)
    output_image = Image.open(png_buffer)
    # Image.open is lazy; decode now so the pixels are fully in memory.
    output_image.load()
    return output_image


def get_base64_font(font_path):
    """Return the contents of the file at *font_path* as a base64 string."""
    with open(font_path, "rb") as font_file:
        return base64.b64encode(font_file.read()).decode()


st.set_page_config(layout="wide")

font_name = "Gloock"
gloock_b64 = get_base64_font("utils/assets/Gloock-Regular.ttf")
font_name_text = "Merriweather sans"
merri_b64 = get_base64_font("utils/assets/MerriweatherSans-Regular.ttf")

# NOTE(review): this template is blank here, so the font names and base64
# payloads above are never referenced -- presumably it was meant to carry the
# CSS that embeds them; confirm against the intended stylesheet.
hide_streamlit_style = f""" """
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

st.title("Low-light event-image enhancement with MIRNet.")

# File uploader widget
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Read the upload as bytes and decode it as an RGB image.
    bytes_data = uploaded_file.getvalue()
    image = Image.open(io.BytesIO(bytes_data)).convert("RGB")

    # Two columns: original on the left, enhanced result on the right.
    col1, col2 = st.columns(2)

    with col1:
        st.image(image, caption="Original Image", use_column_width="always")

    # Button to enhance image
    if st.button("Enhance Image"):
        with col2:
            # Run the MIRNet model on the uploaded image.
            enhanced_image = run_model(image)
            st.image(
                enhanced_image, caption="Enhanced Image", use_column_width="always"
            )

            # Offer the enhanced image for download as a JPEG.
            buf = io.BytesIO()
            enhanced_image.save(buf, format="JPEG")
            byte_im = buf.getvalue()
            st.download_button(
                label="Download image",
                data=byte_im,
                file_name="enhanced_image.jpg",
                mime="image/jpeg",
            )