# app.py — uploaded by ankanpy, commit d60ad95 ("updated app.py"), 7.93 kB.
import streamlit as st
import numpy as np
import cv2
from PIL import Image
import io
import time
from streamlit_drawable_canvas import st_canvas
# Helper functions
def np_to_pil(np_img_bgr: np.ndarray) -> Image.Image:
    """Convert an OpenCV-style NumPy image to a PIL image.

    Args:
        np_img_bgr: A 2-D grayscale array, or a 3-D array whose channels are in
            OpenCV order: 1 channel (gray), 3 channels (BGR) or 4 channels (BGRA).

    Returns:
        A PIL image with channels reordered to L / RGB / RGBA as appropriate.
    """
    if np_img_bgr.ndim == 2:
        return Image.fromarray(np_img_bgr)
    channels = np_img_bgr.shape[-1]
    if channels == 1:
        # A trailing singleton channel axis is not a mode PIL understands.
        return Image.fromarray(np_img_bgr[..., 0])
    if channels == 4:
        # BGRA -> RGBA: swap B and R but keep alpha last. Reversing all four
        # channels would produce ARGB, which PIL would misread as RGBA.
        rgba = np.concatenate([np_img_bgr[..., 2::-1], np_img_bgr[..., 3:]], axis=-1)
        return Image.fromarray(rgba)
    # BGR -> RGB. Copy to a contiguous buffer rather than handing PIL a
    # negative-stride view.
    return Image.fromarray(np.ascontiguousarray(np_img_bgr[..., ::-1]))
def pil_to_np(pil_img) -> np.ndarray:
    """Convert a PIL image (or any array-like) to an OpenCV-style NumPy array.

    Colour images are returned in BGR channel order with any alpha channel
    dropped. Single-channel images are returned as an unchanged 2-D array.

    Args:
        pil_img: A PIL image or anything ``np.array`` accepts.

    Returns:
        A 2-D grayscale array or a 3-D BGR array.
    """
    np_img = np.array(pil_img)
    if np_img.ndim == 2:
        # Single-channel images (e.g. PIL modes "L"/"P") have no channel axis.
        # Without this guard, an image exactly 4 px wide would lose its last
        # column in the alpha-stripping step. Every grayscale image would also
        # be mirrored horizontally by the channel reversal below.
        return np_img
    if np_img.shape[-1] == 2:
        # Luminance + alpha ("LA"): keep the luminance plane only.
        return np_img[..., 0]
    if np_img.shape[-1] == 4:
        # RGBA -> RGB: OpenCV code in this app works on 3-channel images.
        np_img = np_img[..., :3]
    # RGB -> BGR.
    return np_img[..., ::-1]
def download_button_img(np_img_bgr, label, filename):
    """Render a Streamlit download button that serves an image as a PNG file.

    Args:
        np_img_bgr: OpenCV-style image (2-D grayscale or BGR colour array).
        label: Text shown on the button.
        filename: File name offered to the browser for the download.
    """
    with io.BytesIO() as png_buffer:
        np_to_pil(np_img_bgr).save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    st.download_button(label, data=png_bytes, file_name=filename, mime="image/png")
# Set page config (must be the first Streamlit command executed by the script)
st.set_page_config(page_title="Image Restoration App", layout="wide")
st.title("Image Restoration App")

# Upload section
st.sidebar.title("Upload Image")
uploaded_file = st.sidebar.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])

# Session-state defaults. Streamlit re-runs this whole script on every widget
# interaction, so anything that must survive between runs is stored here.
for _state_key in ("orig_image", "current_image", "inpaint_result", "canvas_result", "upload_id"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

if uploaded_file:
    # file_uploader keeps returning the same file on every rerun. Decode it and
    # reset the working images only when a *different* file arrives. Otherwise
    # every button click or canvas stroke would silently discard the filters
    # and inpainting applied so far.
    upload_id = getattr(uploaded_file, "file_id", None) or (uploaded_file.name, uploaded_file.size)
    if upload_id != st.session_state.upload_id:
        file_bytes = np.asarray(bytearray(uploaded_file.getvalue()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)  # 3-channel BGR
        if image is None:
            # imdecode signals an unreadable/corrupt file by returning None.
            st.error("Could not decode the uploaded file as an image.")
            st.stop()
        st.session_state.orig_image = image
        st.session_state.current_image = image.copy()
        st.session_state.inpaint_result = None
        st.session_state.upload_id = upload_id

if st.session_state.orig_image is None:
    st.info("Upload an image to get started.")
    st.stop()
def _adjust_brightness_contrast(img: np.ndarray, brightness: float, contrast: float) -> np.ndarray:
    """Return ``img * (1 + contrast / 100) + brightness`` saturated to uint8.

    ``cv2.convertScaleAbs`` is not used because it takes the absolute value
    before saturating. With it, a negative brightness would turn dark pixels
    bright instead of clamping them to 0.
    """
    alpha = 1 + contrast / 100.0
    adjusted = img.astype(np.float32) * alpha + brightness
    return np.clip(np.rint(adjusted), 0, 255).astype(np.uint8)


# Tabs
tabs = st.tabs(["Filters", "Inpainting", "Compare"])

# FILTERS TAB
with tabs[0]:
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Filters")
        filter_type = st.selectbox(
            "Choose filter:",
            ["None", "Gaussian", "Median", "Bilateral", "Brightness/Contrast", "Grayscale"],
            key="filter",
        )
        # Parameter widgets are shown only for the selected filter.
        if filter_type == "Gaussian":
            ksize = st.slider("Kernel Size", 1, 31, 5, step=2, key="gauss_ksize")  # odd sizes only
            sigma = st.slider("Sigma X", 0.0, 10.0, 2.0, key="gauss_sigma")
        elif filter_type == "Median":
            ksize = st.slider("Kernel Size", 1, 31, 5, step=2, key="median_ksize")  # odd sizes only
        elif filter_type == "Bilateral":
            d = st.slider("Diameter", 1, 30, 9, key="bilateral_d")
            sigmaColor = st.slider("Sigma Color", 1, 150, 75, key="bilateral_color")
            sigmaSpace = st.slider("Sigma Space", 1, 150, 75, key="bilateral_space")
        elif filter_type == "Brightness/Contrast":
            brightness = st.slider("Brightness", -100, 100, 0, key="brightness")
            contrast = st.slider("Contrast", -100, 100, 0, key="contrast")

        if st.button("Apply Filter", key="apply_filter"):
            # Filters are cumulative: each one is applied to the current image.
            img = st.session_state.current_image.copy()
            if filter_type == "Gaussian":
                img = cv2.GaussianBlur(img, (ksize, ksize), sigma)
            elif filter_type == "Median":
                img = cv2.medianBlur(img, ksize)
            elif filter_type == "Bilateral":
                img = cv2.bilateralFilter(img, d, sigmaColor, sigmaSpace)
            elif filter_type == "Brightness/Contrast":
                img = _adjust_brightness_contrast(img, brightness, contrast)
            elif filter_type == "Grayscale" and img.ndim == 3:
                # Already-gray (2-D) images are left alone: COLOR_BGR2GRAY
                # raises on single-channel input.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            st.session_state.current_image = img
            # A previous inpaint result was computed from the old image.
            st.session_state.inpaint_result = None

        if st.button("Reset Image", key="reset_filter"):
            st.session_state.current_image = st.session_state.orig_image.copy()
            st.session_state.inpaint_result = None

    with col2:
        st.subheader("Image Preview")
        img = st.session_state.current_image
        # Colour images are BGR (OpenCV); st.image expects RGB.
        st.image(img if img.ndim == 2 else img[..., ::-1], use_container_width=True)
def _canvas_mask(canvas_result, width: int, height: int):
    """Build a binary inpainting mask from a drawable-canvas result.

    Args:
        canvas_result: The object returned by ``st_canvas`` (or None).
        width: Target mask width in pixels (the working image's width).
        height: Target mask height in pixels.

    Returns:
        A ``uint8`` array of shape ``(height, width)``: 255 where the canvas
        alpha channel is non-zero, 0 elsewhere. Returns None if there is no
        RGBA drawing data yet.
    """
    if canvas_result is None or canvas_result.image_data is None:
        return None
    rgba = canvas_result.image_data
    if rgba.ndim != 3 or rgba.shape[-1] != 4:
        return None
    # Drawn strokes are taken to be the non-transparent pixels of image_data.
    # The canvas is smaller than the image, so scale the mask back up.
    alpha = np.ascontiguousarray(rgba[..., 3])
    alpha = cv2.resize(alpha, (width, height))
    return (alpha > 0).astype(np.uint8) * 255


# INPAINTING TAB
with tabs[1]:
    col1, col2, col3 = st.columns([1, 1.5, 1.5])
    h, w = st.session_state.current_image.shape[:2]
    if "canvas_key" not in st.session_state:
        st.session_state.canvas_key = "canvas"
    if "canvas_resets" not in st.session_state:
        st.session_state.canvas_resets = 0

    with col1:
        st.subheader("Inpainting Settings")
        stroke_width = st.slider("Stroke Width", 1, 25, 5, key="stroke")
        method = st.selectbox("Inpainting Method", ["Telea", "NS"], key="inpaint_method")
        radius = st.slider("Inpaint Radius", 1, 15, 3, key="inpaint_radius")
        if st.button("Apply Inpaint", key="apply_inpaint"):
            # Uses the canvas state stored on the previous run; the canvas
            # itself is rendered further down in col2.
            mask = _canvas_mask(st.session_state.get("canvas_result"), w, h)
            if mask is None or not mask.any():
                st.warning("Draw a mask on the canvas before applying inpainting.")
            else:
                flag = cv2.INPAINT_TELEA if method == "Telea" else cv2.INPAINT_NS
                st.session_state.inpaint_result = cv2.inpaint(
                    st.session_state.current_image, mask, radius, flag
                )
        if st.button("Reset to Original", key="reset_inpaint"):
            st.session_state.current_image = st.session_state.orig_image.copy()
            st.session_state.inpaint_result = None
        st.markdown("---")
        if st.button("Reset Canvas"):
            # A new widget key makes st_canvas start from an empty drawing.
            # A counter (not a timestamp) guarantees a new key even for two
            # clicks within the same second.
            st.session_state.canvas_resets += 1
            st.session_state.canvas_key = f"canvas_{st.session_state.canvas_resets}"
            # Drop the stale drawing so the mask preview and Apply Inpaint do
            # not keep using the erased strokes.
            st.session_state.canvas_result = None

    with col2:
        st.subheader("Draw Mask")
        max_width = 500
        scale = min(1.0, max_width / w)
        # Clamp to at least 1 px so extremely elongated images still get a canvas.
        canvas_w, canvas_h = max(1, int(w * scale)), max(1, int(h * scale))
        show_mask = st.checkbox("Show Mask Preview", key="show_mask")
        if not show_mask:
            pil_bg = np_to_pil(st.session_state.current_image).resize((canvas_w, canvas_h))
            canvas = st_canvas(
                fill_color="white",
                stroke_width=stroke_width,
                stroke_color="black",
                background_image=pil_bg,
                update_streamlit=True,
                height=canvas_h,
                width=canvas_w,
                drawing_mode="freedraw",
                key=st.session_state.canvas_key,
            )
            st.session_state.canvas_result = canvas
        else:
            mask = _canvas_mask(st.session_state.get("canvas_result"), w, h)
            if mask is not None:
                st.image(mask, caption="Inpainting Mask", use_container_width=True)
            else:
                st.info("No mask drawn yet.")

    with col3:
        st.subheader("Inpainting Result")
        result = st.session_state.inpaint_result
        if result is not None:
            # Inpainting a grayscale image yields a 2-D result. Only colour
            # (BGR) results need their channels reversed for display.
            st.image(result if result.ndim == 2 else result[..., ::-1], use_container_width=True)
            download_button_img(result, "Download Inpainted Image", "inpainted_result.png")
        else:
            st.info("Draw a mask and apply inpainting to see result.")
# COMPARE TAB: the uploaded original next to the latest processed image.
with tabs[2]:
    left, right = st.columns(2)

    with left:
        st.subheader("Original Image")
        original = st.session_state.orig_image
        st.image(original[..., ::-1], use_container_width=True)  # BGR -> RGB
        download_button_img(original, "Download Original", "original.png")

    with right:
        st.subheader("Processed Image")
        # Prefer the inpainted image; fall back to the filtered working copy.
        processed = st.session_state.inpaint_result
        if processed is None:
            processed = st.session_state.current_image
        display = processed if processed.ndim == 2 else processed[..., ::-1]
        st.image(display, use_container_width=True)
        download_button_img(processed, "Download Current", "current.png")