Spaces:
Running
Running
from typing import cast | |
from comfydeploy import ComfyDeploy | |
import asyncio | |
import os | |
import gradio as gr | |
from gradio.components.image_editor import EditorValue | |
from PIL import Image | |
import requests | |
import dotenv | |
from gradio_imageslider import ImageSlider | |
from io import BytesIO | |
import base64 | |
import glob | |
import numpy as np | |
# Load variables through python-dotenv before reading the environment.
dotenv.load_dotenv()

API_KEY = os.getenv("API_KEY")
DEPLOYMENT_ID = os.getenv("DEPLOYMENT_ID", "DEPLOYMENT_ID_NOT_SET")

# Fail fast at import time: nothing below can work without both settings.
if API_KEY is None or API_KEY == "":
    raise ValueError(
        "Please set API_KEY and DEPLOYMENT_ID in your environment variables"
    )
if DEPLOYMENT_ID == "DEPLOYMENT_ID_NOT_SET":
    raise ValueError("Please set DEPLOYMENT_ID in your environment variables")

# Single ComfyDeploy client shared by the whole module.
client = ComfyDeploy(bearer_auth=API_KEY)
def get_base64_from_image(image: Image.Image) -> str:
    """Serialize ``image`` as PNG and return those bytes base64-encoded as text.

    Args:
        image: Object whose ``save(buffer, format="PNG")`` writes the PNG data.

    Returns:
        The base64 representation of the PNG bytes, decoded as UTF-8.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
    encoded = base64.b64encode(raw_png)
    return encoded.decode("utf-8")
async def process_image(
    image: Image.Image | str | None,
    mask: Image.Image | str | None,
    progress: gr.Progress = gr.Progress(),
) -> Image.Image | None:
    """Run the ComfyDeploy deployment on an image/mask pair and fetch the result.

    Args:
        image: Picture to process, as a PIL image or a path to open.
        mask: Mask sent alongside it, as a PIL image or a path to open.
        progress: Gradio progress tracker updated while the run is polled.

    Returns:
        The first image found in the run outputs, or ``None`` when an input is
        missing, the run could not be created, the run fails, it succeeds
        without any image, or an exception is raised while talking to the API.
    """
    progress(0, desc="Starting...")
    if image is None or mask is None:
        return None

    if isinstance(mask, str):
        mask = Image.open(mask)
    if isinstance(image, str):
        image = Image.open(image)

    # Both inputs are sent inline as PNG data URLs.
    inputs: dict = {
        "image": f"data:image/png;base64,{get_base64_from_image(image)}",
        "mask": f"data:image/png;base64,{get_base64_from_image(mask)}",
    }

    poll_interval_seconds = 2
    download_timeout_seconds = 60

    try:
        result = client.run.create(
            request={"deployment_id": DEPLOYMENT_ID, "inputs": inputs}
        )
        if not (result and result.object):
            return None

        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")

        # Poll until the run reports "success" or "failed".
        while True:
            run = client.run.get(run_id=run_id).object
            if not run:
                # Pause before retrying; retrying immediately would hammer the
                # API in a tight loop for as long as no run object comes back.
                await asyncio.sleep(poll_interval_seconds)
                continue

            progress_value = run.progress if run.progress is not None else 0
            status = (
                run.live_status
                if run.live_status is not None
                else "Cold starting..."
            )
            progress(progress_value, desc=f"Status: {status}")

            if run.status == "success":
                for output in run.outputs or []:
                    if output.data and output.data.images:
                        image_url: str = output.data.images[0].url
                        # Download the processed image. The timeout keeps a
                        # stalled download from hanging the request forever,
                        # and an HTTP error is reported instead of being
                        # parsed as image data.
                        response: requests.Response = requests.get(
                            image_url, timeout=download_timeout_seconds
                        )
                        response.raise_for_status()
                        return Image.open(BytesIO(response.content))
                return None
            if run.status == "failed":
                print("Processing failed")
                return None

            await asyncio.sleep(poll_interval_seconds)
    except Exception as e:
        # Best-effort boundary: callers treat ``None`` as "processing failed".
        print(f"Error: {e}")
        return None
def make_example(background_path: str, mask_path: str) -> EditorValue:
    """Build an image-editor value from a background picture and a mask file.

    The last channel of the mask image becomes the alpha channel of a single,
    otherwise zero-filled layer. The composite is the background with alpha 0
    wherever that mask value equals 255 and alpha 255 elsewhere.

    Args:
        background_path: Path of the picture used as background.
        mask_path: Path of the mask image; only its last channel is read.

    Returns:
        A dict with the ``background``, ``layers`` and ``composite`` arrays.
    """
    background = np.array(Image.open(background_path))
    mask_alpha = np.array(Image.open(mask_path))[:, :, -1]
    height, width = background.shape[0], background.shape[1]

    layer = np.zeros((height, width, 4), dtype=np.uint8)
    layer[:, :, 3] = mask_alpha

    composite = np.zeros((height, width, 4), dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask_alpha == 255, 0, 255)

    return {"background": background, "layers": [layer], "composite": composite}
def resize_image(img: Image.Image, min_side_length: int = 768) -> Image.Image:
    """Resize ``img`` so that its shorter side equals ``min_side_length``.

    Images whose width and height are both at most ``min_side_length`` are
    returned untouched. Otherwise the shorter side is set to
    ``min_side_length`` and the other side follows the aspect ratio,
    truncated to an integer.

    Args:
        img: Image to resize.
        min_side_length: Target length of the shorter side.

    Returns:
        The original image when it is already small enough, else the result
        of ``img.resize``.
    """
    width, height = img.width, img.height
    if max(width, height) <= min_side_length:
        return img

    aspect_ratio = width / height
    if width < height:
        target_size = (min_side_length, int(min_side_length / aspect_ratio))
    else:
        target_size = (int(min_side_length * aspect_ratio), min_side_length)
    return img.resize(target_size)
async def run_async(
    image_and_mask: EditorValue | None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Validate the editor value, then send the image and mask for processing.

    Args:
        image_and_mask: Value of the mask editor; its ``background`` and first
            entry of ``layers`` are read.
        progress: Gradio progress tracker forwarded to ``process_image``.

    Returns:
        ``(resized input image, processed image)`` on success, or ``None``
        (after a ``gr.Info`` message) when the input is incomplete or the
        processing fails.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    image_np = cast(np.ndarray, image_and_mask["background"])
    # A missing background and an all-zero one both mean nothing was uploaded.
    if image_np is None or np.sum(image_np) == 0:
        gr.Info("Please upload an image")
        return None

    # Without any layer there is nothing drawn; report it instead of failing
    # with an IndexError on ``layers[0]``.
    layers = image_and_mask["layers"]
    if not layers:
        gr.Info("Please mark the areas you want to remove")
        return None

    alpha_channel = cast(np.ndarray, layers[0])
    # Every pixel with a non-zero alpha in the drawn layer is part of the mask.
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if np.sum(mask_np) == 0:
        gr.Info("Please mark the areas you want to remove")
        return None

    mask = resize_image(Image.fromarray(mask_np))
    image = resize_image(Image.fromarray(image_np))

    output = await process_image(
        image,  # type: ignore
        mask,  # type: ignore
        progress,
    )
    if output is None:
        gr.Info("Processing failed")
        return None
    return image, output
def run_sync(*args):
    """Blocking entry point: run ``run_async(*args)`` to completion.

    Returns:
        Whatever ``run_async`` returns for the given arguments.
    """
    pending = run_async(*args)
    return asyncio.run(pending)
# Gradio UI: a mask editor on the left, a before/after slider on the right,
# a "Run" button, and a set of pre-computed examples.
# NOTE(review): widget nesting was reconstructed from a flattened source —
# confirm that the button and the examples belong directly under the Blocks.
with gr.Blocks() as demo:
    gr.HTML("""
    <div style="text-align:center;">
    <h1>🧹 Room Cleaner</h1>
    <div>
    <p>Upload an image and use the pencil tool (✏️ icon at the bottom) to <b>mark the areas you want to remove</b>.</p>
    <p>For best results, include the shadows and reflections of the objects you want to remove.</p>
    <p>You can remove multiple objects at once.</p>
    <p>If you forget to mask some parts of your object, it's likely that the model will reconstruct them.</p>
    <br>
    <video width="640" height="360" controls style="margin: 0 auto; border-radius: 10px;">
    <source src="https://dropshare.blanchon.xyz/public/dropshare/room_cleaner_demo.mp4" type="video/mp4">
    </video>
    <br>
    <p>Finally, click on the <b>"Run"</b> button to process the image.</p>
    <p>Wait for the processing to complete and compare the original and processed images using the slider.</p>
    <p>⚠️ Note that the images are compressed to the workloads of the demo. </p>
    </div>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            # TODO: the image overflows its container; needs a layout fix.
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                height="full",
                width="full",
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
        with gr.Column():
            image_slider = ImageSlider(
                label="Result",
                interactive=False,
            )

    # A ClearButton so that a click first empties the previous result.
    process_btn = gr.ClearButton(
        value="Run",
        variant="primary",
        size="lg",
        components=[image_slider],
    )
    process_btn.click(
        fn=run_sync,
        inputs=[
            image_and_mask,
        ],
        outputs=[image_slider],
        api_name=False,
    )

    # Build examples.
    # glob.glob returns matches in arbitrary order; sort so that the local
    # files line up with the ex1..ex4 result URLs listed below.
    images_examples = sorted(glob.glob("examples/*.jpg"))
    # removesuffix only strips the extension, even if ".jpg" also appears
    # elsewhere in the path.
    mask_examples = [
        img.removesuffix(".jpg") + "_mask_only.png" for img in images_examples
    ]
    # NOTE(review): currently unused, and the suffix has no underscore unlike
    # the mask files — confirm the intended file name before relying on it.
    output_examples = [
        img.removesuffix(".jpg") + "results.png" for img in images_examples
    ]

    # Pre-computed (original, result) pairs shown in the slider per example.
    # NOTE(review): ex1 uses "_results.png" while the others use
    # "_result.png" — confirm against the hosted files.
    result_pairs = [
        (
            "https://dropshare.blanchon.xyz/public/dropshare/ex1.jpg",
            "https://dropshare.blanchon.xyz/public/dropshare/ex1_results.png",
        ),
        (
            "https://dropshare.blanchon.xyz/public/dropshare/ex2.jpg",
            "https://dropshare.blanchon.xyz/public/dropshare/ex2_result.png",
        ),
        (
            "https://dropshare.blanchon.xyz/public/dropshare/ex3.jpg",
            "https://dropshare.blanchon.xyz/public/dropshare/ex3_result.png",
        ),
        (
            "https://dropshare.blanchon.xyz/public/dropshare/ex4.jpg",
            "https://dropshare.blanchon.xyz/public/dropshare/ex4_result.png",
        ),
    ]

    # zip stops at the shortest input, so fewer than four local example files
    # no longer crash the app at startup with an IndexError.
    examples = [
        [make_example(image_path, mask_path), result_pair]
        for image_path, mask_path, result_pair in zip(
            images_examples, mask_examples, result_pairs
        )
    ]

    if examples:
        gr.Examples(
            examples=examples,
            inputs=[
                image_and_mask,
                image_slider,
            ],
            api_name=False,
        )

if __name__ == "__main__":
    demo.launch(debug=True, share=True)