Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import tempfile | |
import os | |
# Function to create the binary mask from the ImageEditor's output | |
def _layer_alpha(layer, h, w):
    """Return the alpha channel of one drawing layer, resampled to ``(h, w)``.

    Assumes the editor delivers RGBA layers (channel index 3 is alpha) -- TODO
    confirm if the editor is ever switched to a non-RGBA image mode.
    """
    alpha = layer[:, :, 3]
    lh, lw = alpha.shape
    if (lh, lw) != (h, w):
        print(
            f"Warning: Layer size ({lh}x{lw}) doesn't match background ({h}x{w}). "
            "Resampling the layer to the background size."
        )
        # Nearest-neighbour index maps keep the mask strictly binary and make
        # the output match the original image dimensions.
        rows = np.arange(h) * lh // h
        cols = np.arange(w) * lw // w
        alpha = alpha[rows[:, None], cols]
    return alpha


def _save_mask_png(mask):
    """Write ``mask`` to a new temporary PNG and return its path, or None on failure.

    Saving is best-effort. Any error is reported, and the partially written
    temporary file is removed so failed attempts do not pile up on disk.
    """
    filepath = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            filepath = tmpfile.name
            # Write through the already-open handle. Reopening the file by name
            # while it is still open is platform-dependent per the tempfile docs.
            Image.fromarray(mask).save(tmpfile, format="PNG")
    except Exception as e:  # UI callback boundary: a failed save must not crash the app
        print(f"Error saving mask to temporary file: {e}")
        if filepath is not None:
            try:
                os.remove(filepath)
            except OSError:
                pass  # already gone or not removable; nothing more to do
        return None
    print(f"Mask saved temporarily to: {filepath}")
    return filepath


def create_binary_mask(im_dict):
    """Build a binary mask from the value of the ImageEditor.

    Args:
        im_dict: The editor value, or None when the editor is empty. This code
            reads its ``"background"`` array and its ``"layers"`` list. Layer
            entries may be None and are otherwise expected to be RGBA arrays.

    Returns:
        A ``(mask, filepath)`` tuple.

        ``mask`` is a 2-D ``uint8`` array the size of the background image. It
        is 255 wherever any drawing layer has non-zero alpha and 0 elsewhere.

        ``filepath`` is a temporary PNG of the mask for the download button. It
        is None when no drawing layer exists or when saving failed.

        With no background image, a blank 768x1024 placeholder and None are
        returned.
    """
    background_img = None if im_dict is None else im_dict.get("background")
    if background_img is None:
        print("No background image found.")
        # Small blank placeholder for the preview; nothing to download.
        blank_preview = np.zeros((768, 1024), dtype=np.uint8)
        return blank_preview, None

    # shape[:2] works for both colour (H, W, C) and grayscale (H, W) images.
    h, w = background_img.shape[:2]
    print(f"Original image dimensions: H={h}, W={w}")

    layers = [layer for layer in (im_dict.get("layers") or []) if layer is not None]
    if not layers:
        print("No drawing layer found. Generating blank mask.")
        # Nothing drawn yet: black mask of the original size, no file to download.
        return np.zeros((h, w), dtype=np.uint8), None

    # A pixel counts as drawn if it is non-transparent on ANY layer, so strokes
    # made on extra layers are not silently dropped.
    drawn = np.zeros((h, w), dtype=bool)
    for layer in layers:
        print(f"Drawing layer dimensions: H={layer.shape[0]}, W={layer.shape[1]}")
        drawn |= _layer_alpha(layer, h, w) > 0

    mask = np.where(drawn, 255, 0).astype(np.uint8)
    print(f"Generated binary mask dimensions: H={mask.shape[0]}, W={mask.shape[1]}")

    # Keep the computed mask for the preview even if saving fails; the download
    # button then simply has no file (filepath is None).
    return mask, _save_mask_png(mask)
# --- Gradio App Layout --- | |
# Two-column UI: the drawing editor on the left; the mask preview and the
# download button on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Binary Mask Generator")
    gr.Markdown(
        "Upload or paste an image. Use the brush tool (select it!) to draw the area "
        "you want to mask. Click 'Generate Mask' to see the result and download it."
    )

    with gr.Row():
        # --- Left column: image input and drawing ---
        with gr.Column(scale=1):
            image_editor = gr.ImageEditor(
                label="Draw on Image",
                # type="numpy" hands create_binary_mask the background and the
                # layers as arrays it can index directly.
                type="numpy",
                # No crop_size/height/width so the original image dimensions are
                # kept. "clipboard" enables pasting, which the instructions above
                # promise; with "upload" alone that path was unavailable.
                sources=["upload", "clipboard"],
                # Fixed red brush so the drawable area is obvious.
                brush=gr.Brush(colors=["#FF0000"], color_mode="fixed"),
                interactive=True,
                canvas_size=(768, 1024),
            )
            generate_button = gr.Button("Generate Mask", variant="primary")

        # --- Right column: results ---
        with gr.Column(scale=1):
            mask_preview = gr.Image(
                label="Binary Mask Preview",
                type="numpy",
                interactive=False,  # display only
            )
            # Its value (the PNG file path) is filled in by the click handler.
            download_button = gr.DownloadButton(
                label="Download Mask (PNG)",
                interactive=True,
            )

    # create_binary_mask returns (mask array, file path): the array goes to the
    # preview and the path becomes the download button's value.
    generate_button.click(
        fn=create_binary_mask,
        inputs=[image_editor],
        outputs=[mask_preview, download_button],
    )
# --- Launch the App --- | |
if __name__ == "__main__":
    # Best-effort startup cleanup of PNGs left in the system temp directory by
    # earlier runs, then launch the app.
    # NOTE(review): "tmp*.png" is also the name pattern any other program's
    # tempfile.NamedTemporaryFile(suffix=".png") produces in this shared
    # directory, so this can delete files that are not ours. Consider giving
    # this app's temp files a distinctive prefix and matching on that instead.
    temp_dir = tempfile.gettempdir()
    for name in os.listdir(temp_dir):
        path = os.path.join(temp_dir, name)
        # Only regular files; a directory that happens to match is left alone.
        if name.startswith("tmp") and name.endswith(".png") and os.path.isfile(path):
            try:
                os.remove(path)
            except OSError:
                pass  # locked, already removed, or not permitted: skip it
    demo.launch(share=True)