| | import os |
| | import numpy as np |
| | import streamlit as st |
| | from PIL import Image, ImageDraw, ImageFilter |
| | import numpy as np |
| | import torch |
| | from streamlit_js_eval import streamlit_js_eval |
| |
|
| |
|
| |
|
| | |
| | from streamlit_image_coordinates import streamlit_image_coordinates |
| |
|
| | |
| | from diffusers import StableDiffusionInpaintPipeline |
| |
|
| |
|
| | |
| | from ultralytics import FastSAM |
| |
|
| | |
# Browser tab title and a centered (narrow-column) page layout.
st.set_page_config(page_title="Inpainting Demo", layout="centered")

# Viewport width in pixels, evaluated in the browser. It is used further down
# to convert click positions on the displayed image back to image pixels.
# NOTE(review): the JS round-trip presumably yields None on the very first
# render, before the browser has answered -- confirm against streamlit_js_eval.
page_width = streamlit_js_eval(
    js_expressions='window.innerWidth',
    key='WIDTH',
    want_output=True,
)

# Local filename of the FastSAM weights and the Hugging Face id of the
# Stable Diffusion inpainting model.
FASTSAM_CHECKPOINT = "FastSAM-x.pt"
SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"
| |
|
| | |
def crop_resize_image(image, target_width=480, target_height=640):
    """Center-crop *image* to the target aspect ratio, then resize it.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` plus ``crop(box)``
            and ``resize(size)`` methods (a ``PIL.Image.Image`` in this app).
        target_width: Output width in pixels.
        target_height: Output height in pixels.

    Returns:
        The result of ``image.resize((target_width, target_height))`` after the
        excess on the longer axis has been trimmed equally from both sides.
    """
    width, height = image.size

    # Compare aspect ratios by cross-multiplication and derive the crop size
    # with integer floor division. Going through a float ratio can truncate
    # one pixel short (e.g. int(49 * (1 / 49)) == 0), which would produce an
    # empty crop box.
    scaled_width = width * target_height
    scaled_height = height * target_width

    if scaled_width > scaled_height:
        # Too wide: keep the full height, trim left and right.
        new_width = max(1, height * target_width // target_height)
        left = (width - new_width) // 2
        image = image.crop((left, 0, left + new_width, height))
    elif scaled_width < scaled_height:
        # Too tall: keep the full width, trim top and bottom.
        new_height = max(1, width * target_height // target_width)
        top = (height - new_height) // 2
        image = image.crop((0, top, width, top + new_height))

    return image.resize((target_width, target_height))
| |
|
| | |
# Fetch the FastSAM weights on first launch if they are not on disk yet.
if not os.path.exists(FASTSAM_CHECKPOINT):
    import requests

    fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt"

    # Download to a temporary name and rename only after success, so that an
    # interrupted or failed download can never be mistaken for a valid
    # checkpoint by the existence check above on the next run.
    partial_path = FASTSAM_CHECKPOINT + ".part"
    with requests.get(fastsam_url, stream=True, timeout=60) as resp:
        # Fail loudly on 4xx/5xx instead of saving an error page as weights.
        resp.raise_for_status()
        with open(partial_path, "wb") as checkpoint_file:
            # Stream in 1 MiB chunks rather than holding the file in memory.
            for chunk in resp.iter_content(chunk_size=1 << 20):
                checkpoint_file.write(chunk)
    os.replace(partial_path, FASTSAM_CHECKPOINT)
| |
|
| | |
@st.cache_resource
def load_models():
    """Build the segmentation model and the inpainting pipeline.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the returned
    objects across script reruns instead of rebuilding them.

    Returns:
        A ``(fastsam_model, sd_pipe)`` tuple: the FastSAM model loaded from
        ``FASTSAM_CHECKPOINT`` and the Stable Diffusion inpainting pipeline
        loaded from ``SD_MODEL_ID``.
    """
    # Use the GPU when torch reports one, otherwise run on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    segmenter = FastSAM(FASTSAM_CHECKPOINT)

    inpainter = StableDiffusionInpaintPipeline.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=None,
    )
    inpainter = inpainter.to(device)
    inpainter.enable_attention_slicing()

    return segmenter, inpainter
| |
|
| | |
fastsam_model, sd_pipe = load_models()

st.subheader("InteractiveInpainting")


def _segment_selection(image, coords):
    """Return an "L"-mode mask (255 = selected) for the clicked points, or None.

    None is returned when there are no points or FastSAM produced no mask, so
    callers can test the result instead of relying on a name being bound.
    """
    if not coords:
        return None

    points = [[int(pt["x"]), int(pt["y"])] for pt in coords]
    labels = [1] * len(points)  # every click is a positive (foreground) prompt
    results = fastsam_model(image, points=points, labels=labels)

    # `masks` is None when nothing was segmented; guard before touching `.data`.
    if not results or results[0].masks is None:
        return None

    masks = results[0].masks.data.cpu().numpy()
    if masks.ndim != 3 or masks.shape[0] == 0:
        return None

    # Union of all returned masks, scaled to 0/255 and matched to image size.
    combined_mask = np.max(masks, axis=0)
    mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
    return mask_img.resize(image.size, Image.NEAREST)


def _overlay_selection(image, mask_img):
    """Return an RGBA copy of *image* with the masked area tinted red."""
    overlay = Image.new("RGBA", image.size, (255, 0, 0, 100))
    # Tint only where the mask is set; everything else stays fully transparent.
    overlay.putalpha(mask_img.point(lambda p: 80 if p > 0 else 0))
    return Image.alpha_composite(image.convert("RGBA"), overlay)


def _find_nearby(coords, point, radius=10):
    """Return the first stored point closer than *radius* px to *point*, or None."""
    for existing in coords:
        if np.hypot(existing["x"] - point["x"], existing["y"] - point["y"]) < radius:
            return existing
    return None


if "img" not in st.session_state:
    # No working image yet: let the user capture one with the camera.
    enable = st.checkbox("Enable camera")
    picture = st.camera_input("Take a picture", disabled=not enable)
    if picture is not None:
        img = Image.open(picture)
        img = crop_resize_image(img, target_width=480, target_height=640)
        st.session_state.img = img
        st.session_state.coords_list = []
        st.rerun()
else:
    img = st.session_state.img

    if "coords_list" not in st.session_state:
        st.session_state.coords_list = []

    # Segment the current selection. This is None when nothing is selected,
    # which also keeps the inpainting check below from hitting an unbound name.
    combined_mask_img = _segment_selection(img, st.session_state.coords_list)
    if combined_mask_img is None:
        seg_overlay = img.copy()
    else:
        seg_overlay = _overlay_selection(img, combined_mask_img)

    # Draw a red dot on every selected point.
    final_img = seg_overlay.copy()
    draw = ImageDraw.Draw(final_img)
    for pt in st.session_state.coords_list:
        cx, cy = int(pt["x"]), int(pt["y"])
        draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill="red")

    # Clicks are scaled from displayed-image coordinates back to image pixels
    # by image width over browser width. `page_width` can be missing (None) or
    # zero before the browser reports it; fall back to no scaling rather than
    # raising.
    scale_factor = img.width / page_width if page_width else 1.0

    raw_click = streamlit_image_coordinates(final_img, key="click_img", use_column_width="always")

    # The component keeps returning its last value on every rerun. Only act on
    # a click that has not been handled yet; otherwise a removed point (or the
    # last point after inpainting) would be re-added by an unrelated rerun.
    if raw_click and raw_click != st.session_state.get("last_click"):
        st.session_state.last_click = raw_click
        new_coord = {
            "x": raw_click["x"] * scale_factor,
            "y": raw_click["y"] * scale_factor,
        }
        # Clicking on (or next to) an existing dot removes it; otherwise add.
        nearby = _find_nearby(st.session_state.coords_list, new_coord)
        if nearby is not None:
            st.session_state.coords_list.remove(nearby)
        else:
            st.session_state.coords_list.append(new_coord)
        st.rerun()

    st.write("Stored coordinates:", st.session_state.coords_list)

    prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")

    if prompt and combined_mask_img is not None:
        combined_mask_img = combined_mask_img.convert("L")
        # NOTE(review): a dilated + Gaussian-blurred copy of the mask used to be
        # computed here but was never passed to the pipeline; it was dropped as
        # dead work. Pass a feathered mask as `mask_image` if softer edges are
        # actually wanted.
        if st.button("Run Inpainting"):
            with st.spinner("Inpainting..."):
                inpainted_img = sd_pipe(
                    prompt=prompt,
                    image=img,
                    mask_image=combined_mask_img,
                    width=img.width,
                    height=img.height,
                    guidance_scale=8,
                    num_inference_steps=50,
                ).images[0]

            # The result becomes the new working image; start a fresh selection.
            st.session_state.img = inpainted_img
            st.session_state.coords_list = []
            st.rerun()