"""Streamlit app for simple image restoration.

The app lets the user upload an image, apply basic OpenCV filters, draw a
mask on a canvas and inpaint the masked region, and compare/download the
original and processed images.

Images are kept in ``st.session_state`` as OpenCV-style arrays: either
3-channel BGR or single-channel grayscale (after the Grayscale filter).
"""
import io
import time

import cv2
import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

# Maximum on-screen width (pixels) of the mask-drawing canvas.
MAX_CANVAS_WIDTH = 500


# Helper functions
def to_display(np_img):
    """Return an array in display order: grayscale unchanged, BGR -> RGB.

    Reversing the last axis is only a channel swap for 3-D arrays; on a 2-D
    grayscale array it would mirror the image horizontally, so 2-D input is
    returned as-is.
    """
    return np_img if np_img.ndim == 2 else np_img[..., ::-1]


def np_to_pil(np_img_bgr):
    """Convert a grayscale or BGR array into a PIL image."""
    return Image.fromarray(to_display(np_img_bgr))


def pil_to_np(pil_img):
    """Convert a PIL image into a grayscale or BGR array.

    An alpha channel, if present, is dropped. Grayscale (2-D) images are
    returned unchanged instead of being mirrored by the channel swap.
    """
    np_img = np.array(pil_img)
    if np_img.ndim == 2:
        return np_img
    if np_img.shape[-1] == 4:
        np_img = np_img[..., :3]
    return np_img[..., ::-1]


def download_button_img(np_img_bgr, label, filename):
    """Render a download button that serves the image encoded as PNG."""
    buf = io.BytesIO()
    np_to_pil(np_img_bgr).save(buf, format="PNG")
    st.download_button(label, data=buf.getvalue(), file_name=filename, mime="image/png")


def canvas_mask(canvas, width, height):
    """Build a binary (0/255) uint8 mask from a canvas result.

    The mask is taken from the alpha channel of the canvas RGBA data (any
    non-transparent pixel counts as masked) and resized to ``width`` x
    ``height``. Returns ``None`` when no usable RGBA data is available.
    """
    if canvas is None or canvas.image_data is None:
        return None
    rgba = canvas.image_data
    if rgba.ndim != 3 or rgba.shape[-1] != 4:
        return None
    alpha = cv2.resize(rgba[..., 3], (width, height))
    return (alpha > 0).astype(np.uint8) * 255


def apply_filter(img, filter_type, params):
    """Return ``img`` processed by the selected filter.

    ``params`` holds the slider values collected for ``filter_type``.
    Unknown filter names (including "None") return the image unchanged.
    """
    if filter_type == "Gaussian":
        ksize = params["ksize"]
        return cv2.GaussianBlur(img, (ksize, ksize), params["sigma"])
    if filter_type == "Median":
        return cv2.medianBlur(img, params["ksize"])
    if filter_type == "Bilateral":
        return cv2.bilateralFilter(
            img, params["d"], params["sigmaColor"], params["sigmaSpace"]
        )
    if filter_type == "Brightness/Contrast":
        return cv2.convertScaleAbs(
            img, alpha=1 + params["contrast"] / 100.0, beta=params["brightness"]
        )
    if filter_type == "Grayscale":
        # Already-grayscale images have no channel axis; converting again
        # would raise inside OpenCV.
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
    return img


def new_canvas_key():
    """Return a fresh widget key so the drawing canvas is re-created empty."""
    # Nanosecond resolution avoids reusing a key on two quick resets.
    return f"canvas_{time.time_ns()}"


# Set page config
st.set_page_config(page_title="Image Restoration App", layout="wide")
st.title("Image Restoration App")

# Upload section
st.sidebar.title("Upload Image")
uploaded_file = st.sidebar.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])

for _state_key in (
    "orig_image",
    "current_image",
    "inpaint_result",
    "canvas_result",
    "upload_token",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
if "canvas_key" not in st.session_state:
    st.session_state.canvas_key = "canvas"

if uploaded_file is not None:
    # The script reruns on every widget interaction while the uploader still
    # holds the file. Only (re)load the image when the upload itself changed,
    # otherwise every rerun would discard applied filters and inpaint results.
    upload_token = (
        getattr(uploaded_file, "file_id", None),
        uploaded_file.name,
        uploaded_file.size,
    )
    if upload_token != st.session_state.upload_token:
        # getvalue() returns the full content regardless of the read position.
        file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
        if image is None:
            st.sidebar.error("Could not decode the uploaded file as an image.")
        else:
            st.session_state.orig_image = image
            st.session_state.current_image = image.copy()
            st.session_state.inpaint_result = None
            # Strokes drawn for the previous image must not mask the new one.
            st.session_state.canvas_result = None
            st.session_state.canvas_key = new_canvas_key()
            st.session_state.upload_token = upload_token
else:
    # Forget the token so re-uploading the same file loads it again.
    st.session_state.upload_token = None

if st.session_state.orig_image is None:
    st.info("Upload an image to get started.")
    st.stop()

# Tabs
tabs = st.tabs(["Filters", "Inpainting", "Compare"])

# FILTERS TAB
with tabs[0]:
    col1, col2 = st.columns([1, 2])

    with col1:
        st.subheader("Filters")
        filter_type = st.selectbox(
            "Choose filter:",
            ["None", "Gaussian", "Median", "Bilateral", "Brightness/Contrast", "Grayscale"],
            key="filter",
        )

        filter_params = {}
        if filter_type == "Gaussian":
            filter_params["ksize"] = st.slider(
                "Kernel Size", 1, 31, 5, step=2, key="gauss_ksize"
            )
            filter_params["sigma"] = st.slider(
                "Sigma X", 0.0, 10.0, 2.0, key="gauss_sigma"
            )
        elif filter_type == "Median":
            filter_params["ksize"] = st.slider(
                "Kernel Size", 1, 31, 5, step=2, key="median_ksize"
            )
        elif filter_type == "Bilateral":
            filter_params["d"] = st.slider("Diameter", 1, 30, 9, key="bilateral_d")
            filter_params["sigmaColor"] = st.slider(
                "Sigma Color", 1, 150, 75, key="bilateral_color"
            )
            filter_params["sigmaSpace"] = st.slider(
                "Sigma Space", 1, 150, 75, key="bilateral_space"
            )
        elif filter_type == "Brightness/Contrast":
            filter_params["brightness"] = st.slider(
                "Brightness", -100, 100, 0, key="brightness"
            )
            filter_params["contrast"] = st.slider(
                "Contrast", -100, 100, 0, key="contrast"
            )

        if st.button("Apply Filter", key="apply_filter"):
            st.session_state.current_image = apply_filter(
                st.session_state.current_image.copy(), filter_type, filter_params
            )
            # The previous inpaint result no longer matches the filtered image.
            st.session_state.inpaint_result = None

        if st.button("Reset Image", key="reset_filter"):
            st.session_state.current_image = st.session_state.orig_image.copy()
            st.session_state.inpaint_result = None

    with col2:
        st.subheader("Image Preview")
        st.image(to_display(st.session_state.current_image), use_container_width=True)

# INPAINTING TAB
with tabs[1]:
    col1, col2, col3 = st.columns([1, 1.5, 1.5])

    with col1:
        st.subheader("Inpainting Settings")
        stroke_width = st.slider("Stroke Width", 1, 25, 5, key="stroke")
        method = st.selectbox("Inpainting Method", ["Telea", "NS"], key="inpaint_method")

        if st.button("Apply Inpaint", key="apply_inpaint"):
            # The canvas is rendered later in this run (second column), so the
            # mask comes from the canvas result stored on the previous run.
            h, w = st.session_state.current_image.shape[:2]
            mask = canvas_mask(st.session_state.get("canvas_result"), w, h)
            if mask is None:
                st.warning("Draw a mask on the canvas before applying inpainting.")
            else:
                flag = cv2.INPAINT_TELEA if method == "Telea" else cv2.INPAINT_NS
                st.session_state.inpaint_result = cv2.inpaint(
                    st.session_state.current_image, mask, 3, flag
                )

        if st.button("Reset to Original", key="reset_inpaint"):
            st.session_state.current_image = st.session_state.orig_image.copy()
            st.session_state.inpaint_result = None

        st.markdown("---")
        if st.button("Reset Canvas"):
            st.session_state.canvas_key = new_canvas_key()

    with col2:
        st.subheader("Draw Mask")
        h, w = st.session_state.current_image.shape[:2]
        scale = min(1.0, MAX_CANVAS_WIDTH / w)
        # Clamp to at least one pixel so extreme aspect ratios stay drawable.
        canvas_w, canvas_h = max(1, int(w * scale)), max(1, int(h * scale))

        show_mask = st.checkbox("Show Mask Preview", key="show_mask")

        if not show_mask:
            pil_bg = np_to_pil(st.session_state.current_image).resize((canvas_w, canvas_h))
            canvas = st_canvas(
                fill_color="white",
                stroke_width=stroke_width,
                stroke_color="black",
                background_image=pil_bg,
                update_streamlit=True,
                height=canvas_h,
                width=canvas_w,
                drawing_mode="freedraw",
                key=st.session_state.canvas_key,
            )
            st.session_state.canvas_result = canvas
        else:
            mask = canvas_mask(st.session_state.get("canvas_result"), w, h)
            if mask is not None:
                st.image(mask, caption="Inpainting Mask", use_container_width=True)

    with col3:
        st.subheader("Inpainting Result")
        result = st.session_state.inpaint_result
        if result is not None:
            # The result is grayscale when the source was; to_display avoids
            # mirroring it with a channel swap.
            st.image(to_display(result), use_container_width=True)
            download_button_img(result, "Download Inpainted Image", "inpainted_result.png")
        else:
            st.info("Draw a mask and apply inpainting to see result.")

# COMPARE TAB
with tabs[2]:
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Original Image")
        orig = st.session_state.orig_image
        st.image(to_display(orig), use_container_width=True)
        download_button_img(orig, "Download Original", "original.png")

    with col2:
        st.subheader("Processed Image")
        current = (
            st.session_state.inpaint_result
            if st.session_state.inpaint_result is not None
            else st.session_state.current_image
        )
        st.image(to_display(current), use_container_width=True)
        download_button_img(current, "Download Current", "current.png")