# room_cleaner / app.py
# Commit 50704a0 (blanchon): Add examples with mask
from typing import cast
from comfydeploy import ComfyDeploy
import asyncio
import os
import gradio as gr
from gradio.components.image_editor import EditorValue
from PIL import Image
import requests
import dotenv
from gradio_imageslider import ImageSlider
from io import BytesIO
import base64
import glob
import numpy as np
# Let python-dotenv populate the process environment before it is read below.
dotenv.load_dotenv()

# ComfyDeploy credentials and target deployment, both taken from the environment.
API_KEY = os.environ.get("API_KEY")
DEPLOYMENT_ID = os.environ.get("DEPLOYMENT_ID", "DEPLOYMENT_ID_NOT_SET")

# Fail fast at import time so a misconfigured app never starts. Each check
# names only the variable that is actually missing.
if not API_KEY:
    raise ValueError("Please set API_KEY in your environment variables")
# An empty value is as unusable as a missing one, so treat both as "not set".
if not DEPLOYMENT_ID or DEPLOYMENT_ID == "DEPLOYMENT_ID_NOT_SET":
    raise ValueError("Please set DEPLOYMENT_ID in your environment variables")

# Shared ComfyDeploy API client used by process_image.
client = ComfyDeploy(bearer_auth=API_KEY)
def get_base64_from_image(image: Image.Image) -> str:
    """Serialize ``image`` as PNG and return those bytes as a base64 string.

    Args:
        image: The picture to encode.

    Returns:
        The base64 text (without any ``data:`` prefix) of the PNG encoding.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
async def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
) -> Image.Image | None:
    """Run the ComfyDeploy deployment on an image/mask pair and fetch the result.

    Args:
        image: Source picture, either a PIL image or a path to open.
        mask: Mask picture, either a PIL image or a path to open.
        progress: Gradio progress tracker, updated while the run is polled.

    Returns:
        The first image produced by the run, or ``None`` when an input is
        missing, the run could not be created, the run fails, the run
        succeeds without producing an image, or an exception is raised while
        talking to the API or downloading the result.
    """
    progress(0, desc="Starting...")
    if image is None or mask is None:
        return None

    if isinstance(mask, str):
        mask = Image.open(mask)
    if isinstance(image, str):
        image = Image.open(image)

    image_base64 = get_base64_from_image(image)
    mask_base64 = get_base64_from_image(mask)

    # Both inputs are sent inline as base64-encoded PNG data URLs.
    inputs: dict = {
        "image": f"data:image/png;base64,{image_base64}",
        "mask": f"data:image/png;base64,{mask_base64}",
    }

    poll_interval_s = 2  # Seconds to wait between two status checks.
    download_timeout_s = 60  # Upper bound for fetching the result image.

    try:
        result = client.run.create(
            request={"deployment_id": DEPLOYMENT_ID, "inputs": inputs}
        )
        if not (result and result.object):
            return None

        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")

        # Poll until the run reports success or failure.
        # NOTE(review): any other terminal status the API may report would keep
        # this loop polling forever - confirm against the ComfyDeploy docs.
        while True:
            run = client.run.get(run_id=run_id).object
            if not run:
                # No payload yet: wait like after any other poll instead of
                # retrying immediately, which would spin and flood the API.
                await asyncio.sleep(poll_interval_s)
                continue

            progress_value = run.progress if run.progress is not None else 0
            status = (
                run.live_status if run.live_status is not None else "Cold starting..."
            )
            progress(progress_value, desc=f"Status: {status}")

            if run.status == "success":
                for output in run.outputs or []:
                    if output.data and output.data.images:
                        image_url: str = output.data.images[0].url
                        # Download the processed image; a stalled or failed
                        # request raises and is reported by the handler below
                        # rather than hanging or decoding an error page.
                        response: requests.Response = requests.get(
                            image_url, timeout=download_timeout_s
                        )
                        response.raise_for_status()
                        return Image.open(BytesIO(response.content))
                # The run succeeded but produced no image.
                return None
            if run.status == "failed":
                print("Processing failed")
                return None

            await asyncio.sleep(poll_interval_s)
    except Exception as e:
        # Best effort: callers treat None as "processing failed".
        print(f"Error: {e}")
        return None
def make_example(background_path: str, mask_path: str) -> EditorValue:
    """Assemble an image-editor value from a background file and a mask file.

    Args:
        background_path: Path of the picture used as the editor background.
        mask_path: Path of the mask picture; its last channel is the mask.

    Returns:
        A dict with the ``background`` array, a single four-channel layer
        carrying the mask in its last channel, and a ``composite`` whose
        alpha is 0 where the mask equals 255 and 255 everywhere else.
    """
    background = np.array(Image.open(background_path))
    mask_channel = np.array(Image.open(mask_path))[:, :, -1]
    rows, cols = background.shape[0], background.shape[1]

    # Drawing layer: colour channels stay zero, the mask goes into alpha.
    layer = np.zeros((rows, cols, 4), dtype=np.uint8)
    layer[:, :, 3] = mask_channel

    # Composite: the background colours, made transparent under the mask.
    composite = np.zeros((rows, cols, 4), dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask_channel == 255, 0, 255)

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }
def resize_image(img: Image.Image, min_side_length: int = 768) -> Image.Image:
    """Downscale ``img`` so that its shorter side equals ``min_side_length``.

    The aspect ratio is preserved. An image whose shorter side is already at
    or below ``min_side_length`` is returned unchanged, so this function never
    enlarges its input.

    Args:
        img: The picture to resize.
        min_side_length: Target length, in pixels, of the shorter side.

    Returns:
        ``img`` itself when no downscaling is needed, otherwise a resized copy.
    """
    # Compare the *shorter* side with the target. Requiring both sides to be
    # small would let a wide or tall image through to the scaling below,
    # which would then enlarge it (e.g. 769x100 -> 5905x768).
    if min(img.width, img.height) <= min_side_length:
        return img

    aspect_ratio = img.width / img.height
    if img.width < img.height:
        # Portrait: the width is the shorter side.
        new_height = int(min_side_length / aspect_ratio)
        return img.resize((min_side_length, new_height))
    # Landscape or square: the height is the shorter side.
    new_width = int(min_side_length * aspect_ratio)
    return img.resize((new_width, min_side_length))
async def run_async(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Validate the editor value, build the mask and run the processing.

    Args:
        image_and_mask: Value of the image editor (background plus layers).
        progress: Gradio progress tracker forwarded to ``process_image``.

    Returns:
        ``(resized input image, processed image)``, or ``None`` after showing
        a ``gr.Info`` message when the input is incomplete or processing
        fails.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    background = image_and_mask["background"]
    # A missing background and an all-zero one both count as "no image".
    if background is None or np.sum(cast(np.ndarray, background)) == 0:
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, background)

    # Without any layer there is nothing drawn; report it like an empty mask
    # instead of failing on the index below.
    layers = image_and_mask["layers"]
    if not layers:
        gr.Info("Please mark the areas you want to remove")
        return None

    # Every pixel of the first layer with non-zero alpha belongs to the mask.
    alpha_channel = cast(np.ndarray, layers[0])
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if np.sum(mask_np) == 0:
        gr.Info("Please mark the areas you want to remove")
        return None

    # Image and mask go through the same resize so they stay aligned.
    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    output = await process_image(
        image,  # type: ignore
        mask,  # type: ignore
        progress,
    )
    if output is None:
        gr.Info("Processing failed")
        return None
    return image, output
def run_sync(*args):
    """Blocking wrapper: run ``run_async`` with ``args`` and return its result."""
    pending = run_async(*args)
    return asyncio.run(pending)
# Gradio UI: instructions, the mask editor next to the result slider, the Run
# button and a set of ready-made examples.
with gr.Blocks() as demo:
    gr.HTML("""
<div style="text-align:center;">
<h1>🧹 Room Cleaner</h1>
<div>
<p>Upload an image and use the pencil tool (✏️ icon at the bottom) to <b>mark the areas you want to remove</b>.</p>
<p>For best results, include the shadows and reflections of the objects you want to remove.</p>
<p>You can remove multiple objects at once.</p>
<p>If you forget to mask some parts of your object, it's likely that the model will reconstruct them.</p>
<br>
<video width="640" height="360" controls style="margin: 0 auto; border-radius: 10px;">
<source src="https://dropshare.blanchon.xyz/public/dropshare/room_cleaner_demo.mp4" type="video/mp4">
</video>
<br>
<p>Finally, click on the <b>"Run"</b> button to process the image.</p>
<p>Wait for the processing to complete and compare the original and processed images using the slider.</p>
<p>⚠️ Note that the images are compressed to the workloads of the demo. </p>
</div>
</div>
""")

    with gr.Row():
        with gr.Column():
            # TODO: the image overflows its container; fix.
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
        with gr.Column():
            image_slider = ImageSlider(
                label="Result",
                interactive=False,
            )

    # NOTE(review): widget nesting was reconstructed from flattened source;
    # confirm that the button belongs below the row.
    # The button is a ClearButton bound to the slider and also triggers the run.
    process_btn = gr.ClearButton(
        value="Run",
        variant="primary",
        size="lg",
        components=[image_slider],
    )
    process_btn.click(
        fn=run_sync,
        inputs=[
            image_and_mask,
        ],
        outputs=[image_slider],
        api_name=False,
    )

    # Build examples.
    # glob.glob gives no ordering guarantee, so sort the paths: the local
    # files are paired by position with the hosted result URLs below.
    images_examples = sorted(glob.glob("examples/*.jpg"))
    mask_examples = [
        img.removesuffix(".jpg") + "_mask_only.png" for img in images_examples
    ]
    # Hosted (original, result) pairs shown in the slider for each example.
    # NOTE(review): ex1 uses "_results" while the others use "_result";
    # confirm the hosted file names.
    result_examples = [
        (
            "https://dropshare.blanchon.xyz/public/dropshare/ex1.jpg",
            "https://dropshare.blanchon.xyz/public/dropshare/ex1_results.png",
        ),
        (
            "https://dropshare.blanchon.xyz/public/dropshare/ex2.jpg",
            "https://dropshare.blanchon.xyz/public/dropshare/ex2_result.png",
        ),
        (
            "https://dropshare.blanchon.xyz/public/dropshare/ex3.jpg",
            "https://dropshare.blanchon.xyz/public/dropshare/ex3_result.png",
        ),
        (
            "https://dropshare.blanchon.xyz/public/dropshare/ex4.jpg",
            "https://dropshare.blanchon.xyz/public/dropshare/ex4_result.png",
        ),
    ]
    # zip stops at the shortest input, so having fewer local examples than
    # hosted results no longer crashes the app at import time. Entries whose
    # mask file is missing are skipped for the same reason.
    examples = [
        [make_example(img, mask), result]
        for img, mask, result in zip(images_examples, mask_examples, result_examples)
        if os.path.isfile(mask)
    ]
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[
                image_and_mask,
                image_slider,
            ],
            api_name=False,
        )

if __name__ == "__main__":
    demo.launch(debug=True, share=True)