import os

import numpy as np
import requests
import streamlit as st
import torch
from PIL import Image, ImageDraw, ImageFilter
from streamlit_js_eval import streamlit_js_eval

# Import the custom component that reports click coordinates on an image
from streamlit_image_coordinates import streamlit_image_coordinates

# Import the diffusers pipeline for Stable Diffusion inpainting
from diffusers import StableDiffusionInpaintPipeline

# Ultralytics provides the FastSAM model class
from ultralytics import FastSAM

# Set page config for a better mobile experience
st.set_page_config(page_title="Inpainting Demo", layout="centered")

# Browser width, used later to remap click coordinates back to image pixels
page_width = streamlit_js_eval(js_expressions="window.innerWidth", key="WIDTH", want_output=True)

# Define model paths or IDs for easy switching in the future
FASTSAM_CHECKPOINT = "FastSAM-x.pt"  # file name of the FastSAM model weights
SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"  # HF Hub model for SD Inpainting v1.5


# Helper function: center crop to the target aspect ratio, then resize
# (480x640 portrait by default)
def crop_resize_image(image, target_width=480, target_height=640):
    desired_ratio = target_width / target_height  # 480/640 = 0.75 by default
    width, height = image.size
    current_ratio = width / height
    # Crop horizontally if the image is too wide
    if current_ratio > desired_ratio:
        new_width = int(height * desired_ratio)
        left = (width - new_width) // 2
        right = left + new_width
        image = image.crop((left, 0, right, height))
    # Crop vertically if the image is too tall
    elif current_ratio < desired_ratio:
        new_height = int(width / desired_ratio)
        top = (height - new_height) // 2
        bottom = top + new_height
        image = image.crop((0, top, width, bottom))
    return image.resize((target_width, target_height))
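
# Worked example for the crop (hypothetical numbers): a 1000x800 capture has
# ratio 1.25, wider than the target 480/640 = 0.75, so it is center-cropped to
# 600x800 (new_width = int(800 * 0.75)) and only then resized down to 480x640,
# which avoids stretching the image.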
# Ensure the FastSAM model weights are available (download if not present).
# We use the official Ultralytics release URL for FastSAM-x (68MB).
if not os.path.exists(FASTSAM_CHECKPOINT):
    fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt"
    resp = requests.get(fastsam_url)
    resp.raise_for_status()  # fail loudly if the download did not succeed
    with open(FASTSAM_CHECKPOINT, "wb") as f:
        f.write(resp.content)


# Load models with caching to avoid reloading on each interaction
@st.cache_resource
def load_models():
    # Load the FastSAM checkpoint
    # (Ultralytics handles device assignment internally when the model is called)
    fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
    # Load the Stable Diffusion inpainting pipeline
    sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=None,  # default dtype (float32); pass torch.float16 for half precision on GPU
    )
    # Move the pipeline to GPU for faster inference, if one is available
    sd_pipe = sd_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    # (Optional) Enable memory optimizations
    sd_pipe.enable_attention_slicing()  # reduces peak memory usage
    return fastsam_model, sd_pipe


# Initialize the models (this runs only once thanks to caching)
fastsam_model, sd_pipe = load_models()

# Track whether the last click removed a dot, so the same click is not
# immediately re-added as a new dot on the rerun
if "is_removing_dot" not in st.session_state:
    st.session_state.is_removing_dot = False

# Title
st.subheader("Interactive Inpainting")

# Capture an image from the camera and process it
if "img" not in st.session_state:
    enable = st.checkbox("Enable camera")
    # Camera input widget (opens the device camera on mobile/desktop)
    picture = st.camera_input("Take a picture", disabled=not enable)
    if picture is not None:
        img = Image.open(picture)
        img = crop_resize_image(img, target_width=480, target_height=640)
        st.session_state.img = img
        # Reset the coordinates list on a new capture
        st.session_state.coords_list = []
        st.rerun()
    # No picture yet: stop the script here, since everything below needs `img`
    st.stop()
else:
    img = st.session_state.img

# Initialize the coordinates list if it doesn't exist
if "coords_list" not in st.session_state:
    st.session_state.coords_list = []

# --- Compute Segmentation Overlay ---
# If any points have been stored, run point-prompted segmentation with FastSAM.
combined_mask_img = None  # also consumed by the inpainting logic further down
if st.session_state.coords_list:
    points = [[int(pt["x"]), int(pt["y"])] for pt in st.session_state.coords_list]
    labels = [1] * len(points)  # 1 marks each click as a foreground point
    results = fastsam_model(img, points=points, labels=labels)
    # results[0].masks.data is a tensor with shape (N, H, W); masks can be
    # None when nothing was segmented at the clicked points
    masks_tensor = results[0].masks.data if results[0].masks is not None else None
    masks = masks_tensor.cpu().numpy() if masks_tensor is not None else None
    if masks is not None and masks.ndim == 3 and masks.shape[0] > 0:
        # Combine the masks (logical OR via max)
        combined_mask = np.max(masks, axis=0)
        combined_mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
        # Resize the mask to ensure it matches the base image size
        combined_mask_img = combined_mask_img.resize(img.size, Image.NEAREST)
        # Create a semi-transparent red overlay for the segmented area
        overlay = Image.new("RGBA", img.size, (255, 0, 0, 100))
        base = img.convert("RGBA")
        mask_alpha = combined_mask_img.point(lambda p: 80 if p > 0 else 0)
        overlay.putalpha(mask_alpha)
        seg_overlay = Image.alpha_composite(base, overlay)
    else:
        seg_overlay = img.copy()
else:
    seg_overlay = img.copy()

# --- Draw Red Dots on Top ---
final_img = seg_overlay.copy()
draw = ImageDraw.Draw(final_img)
for pt in st.session_state.coords_list:
    cx, cy = int(pt["x"]), int(pt["y"])
    draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill="red")

# Get the original width from the image stored in session_state.
original_width = st.session_state.img.width  # e.g. 480 from crop_resize_image
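
# Coordinate remapping (an assumption worth stating: with
# use_column_width="always" the component renders the image at roughly the
# browser width reported by window.innerWidth, so a click at displayed x maps
# back to the original image as x * (original_width / page_width)).
# Worked example with made-up numbers: original_width=480 and page_width=960
# give scale_factor=0.5, so a click at displayed x=600 lands at original x=300.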
# Compute the scaling factor from displayed pixels back to original pixels.
# page_width can be None on the very first run, before the JS value arrives;
# fall back to no scaling in that case.
scale_factor = original_width / page_width if page_width else 1.0

# Use the interactive component as the display canvas, showing the image with all dots.
new_coord = streamlit_image_coordinates(final_img, key="click_img", use_column_width="always")

# Remap from displayed coordinates to original-image coordinates
if new_coord:
    new_coord = {
        "x": new_coord["x"] * scale_factor,
        "y": new_coord["y"] * scale_factor,
    }

# If a new coordinate arrived, either remove a nearby existing dot (clicking a
# dot toggles it off) or store the click as a new dot, then force a rerun.
if new_coord and new_coord not in st.session_state.coords_list and not st.session_state.is_removing_dot:
    is_close = False
    for coord in st.session_state.coords_list:
        existing = np.array([coord["x"], coord["y"]])
        new = np.array([new_coord["x"], new_coord["y"]])
        if np.linalg.norm(existing - new) < 10:
            is_close = True
            break
    if is_close:
        # The click landed on an existing dot: remove it instead of adding
        st.session_state.coords_list.remove(coord)
        st.session_state.is_removing_dot = True
    else:
        st.session_state.coords_list.append(new_coord)
    st.rerun()
else:
    st.session_state.is_removing_dot = False

st.write("Stored coordinates:", st.session_state.coords_list)

# --- Inpainting Logic ---
prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")

# If there's a prompt and a combined mask from the clicked points, allow inpainting
if prompt and combined_mask_img is not None:
    combined_mask_img = combined_mask_img.convert("L")
    # Dilate the mask so the inpainted region fully covers the selected object
    dilated_mask = combined_mask_img.filter(ImageFilter.MaxFilter(5))
    # Blur the mask edges for a softer transition (adjust the radius as needed)
    blurred_mask = dilated_mask.filter(ImageFilter.GaussianBlur(radius=3))
    if st.button("Run Inpainting"):
        with st.spinner("Inpainting..."):
            # Run Stable Diffusion inpainting with the dilated, blurred mask
            inpainted_img = sd_pipe(
                prompt=prompt,
                image=img,
                mask_image=blurred_mask,
                width=img.width,
                height=img.height,
                guidance_scale=8,  # how strongly to follow the prompt
                num_inference_steps=50,
            ).images[0]
        # Update the session image to the newly inpainted result
        st.session_state.img = inpainted_img
        # Reset the points; they referred to the previous image
        st.session_state.coords_list = []
        st.rerun()
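
# Usage (a sketch; package names are the usual PyPI ones and the file name
# app.py is an assumption):
#   pip install streamlit streamlit-js-eval streamlit-image-coordinates \
#       diffusers transformers accelerate torch ultralytics pillow requests
#   streamlit run app.py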