Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import tempfile | |
| import os | |
def create_binary_mask(im_dict):
    """Build a black/white mask from the drawing layers of a ``gr.ImageEditor`` value.

    Args:
        im_dict: The editor value delivered with ``type="numpy"``. It is a dict
            with a ``"background"`` array and a ``"layers"`` list of arrays, or
            ``None`` when the editor is empty.

    Returns:
        A ``(mask, filepath)`` tuple.

        - ``mask`` is a 2-D ``uint8`` array with the background's height and
          width. It is 255 wherever any drawing layer is painted and 0 elsewhere.
          With no background, ``mask`` is a 768x1024 black placeholder.
        - ``filepath`` is the path of a temporary PNG copy of the mask for the
          download button. It is ``None`` when there is no background, no
          drawing layer, or the file could not be written.
    """
    background_img = None if im_dict is None else im_dict.get("background")
    if background_img is None:
        print("No background image found.")
        # Black placeholder preview; nothing is offered for download.
        blank_preview = np.zeros((768, 1024), dtype=np.uint8)
        return blank_preview, None

    # shape[:2] works for grayscale (H, W) as well as (H, W, C) backgrounds.
    h, w = background_img.shape[:2]
    print(f"Original image dimensions: H={h}, W={w}")

    # Skip missing layers instead of letting a leading None hide the others.
    layers = [layer for layer in (im_dict.get("layers") or []) if layer is not None]
    if not layers:
        print("No drawing layer found. Generating blank mask.")
        # Nothing drawn yet: black mask of the original size, nothing to download.
        return np.zeros((h, w), dtype=np.uint8), None

    # Union of the painted pixels across every layer, not just the first.
    drawn = np.zeros((h, w), dtype=bool)
    for layer in layers:
        lh, lw = layer.shape[:2]
        print(f"Drawing layer dimensions: H={lh}, W={lw}")
        if (lh, lw) != (h, w):
            print(
                f"Warning: Layer size ({lh}x{lw}) doesn't match background "
                f"({h}x{w}). Cropping/padding it to the background size."
            )
        # Only the overlapping region contributes, so the mask always matches
        # the background's dimensions.
        rows, cols = min(lh, h), min(lw, w)
        drawn[:rows, :cols] |= _painted_pixels(layer)[:rows, :cols]

    # White (255) where drawn, black (0) otherwise.
    mask = np.where(drawn, 255, 0).astype(np.uint8)
    print(f"Generated binary mask dimensions: H={mask.shape[0]}, W={mask.shape[1]}")

    # On a save failure the preview still shows the real mask; only the
    # download is unavailable (filepath is None).
    return mask, _save_mask_png(mask)


def _painted_pixels(layer):
    """Return a boolean (H, W) array marking the painted pixels of one editor layer."""
    if layer.ndim == 3 and layer.shape[2] == 4:
        # RGBA layer: any non-zero alpha is brushwork.
        return layer[:, :, 3] > 0
    if layer.ndim == 3:
        # No alpha channel (e.g. a non-RGBA image_mode). As a heuristic, treat any
        # non-black pixel as painted; this assumes the unpainted layer is black.
        return layer.any(axis=2)
    return layer > 0


def _save_mask_png(mask):
    """Write *mask* to a new temporary PNG file; return its path, or None on failure."""
    filepath = None
    try:
        # The "tmp_binary_mask_" prefix makes these files identifiable as this
        # app's output.
        with tempfile.NamedTemporaryFile(
            prefix="tmp_binary_mask_", suffix=".png", delete=False
        ) as tmpfile:
            filepath = tmpfile.name
            # Write through the already-open handle. Reopening a live
            # NamedTemporaryFile by name fails on Windows.
            Image.fromarray(mask).save(tmpfile, format="PNG")
    except (OSError, ValueError) as e:
        print(f"Error saving mask to temporary file: {e}")
        if filepath is not None:
            # delete=False means a partially written file would otherwise linger.
            try:
                os.remove(filepath)
            except OSError:
                pass
        return None
    print(f"Mask saved temporarily to: {filepath}")
    return filepath
# --- Gradio App Layout ---
with gr.Blocks() as demo:
    gr.Markdown("## Binary Mask Generator")
    gr.Markdown(
        "Upload or paste an image. Use the brush tool (select it!) to draw the area "
        "you want to mask. Click 'Generate Mask' to see the result and download it."
    )
    with gr.Row():
        # --- Left Column: editor input ---
        with gr.Column(scale=1):
            image_editor = gr.ImageEditor(
                label="Draw on Image",
                # create_binary_mask reads .shape and indexes the background and
                # layers directly, so they must arrive as NumPy arrays.
                type="numpy",
                # crop_size / height / width are deliberately left unset so
                # uploads keep their original dimensions.
                # "clipboard" enables the "paste" promised in the instructions
                # above; with "upload" alone, pasting is not possible.
                sources=["upload", "clipboard"],
                # A single fixed red brush keeps the drawing UI simple.
                brush=gr.Brush(colors=["#FF0000"], color_mode="fixed"),
                interactive=True,
                canvas_size=(768, 1024),
            )
            generate_button = gr.Button("Generate Mask", variant="primary")
        # --- Right Column: result ---
        with gr.Column(scale=1):
            mask_preview = gr.Image(
                label="Binary Mask Preview",
                type="numpy",
                interactive=False,  # display only
            )
            # Its value (the file path) is set by create_binary_mask's second output.
            download_button = gr.DownloadButton(
                label="Download Mask (PNG)",
                interactive=True,
            )
    # --- Event Handling ---
    generate_button.click(
        fn=create_binary_mask,
        inputs=[image_editor],
        # Output 1 -> mask_preview (image data); output 2 -> download_button (file path).
        outputs=[mask_preview, download_button],
    )
# --- Launch the App ---
if __name__ == "__main__":
    # Best-effort removal of leftover mask files from earlier runs. The match is
    # restricted to tmp_binary_mask_*.png rather than every tmp*.png: the system
    # temp directory is shared, and a broader pattern would delete other
    # programs' temporary files.
    temp_dir = tempfile.gettempdir()
    for item in os.listdir(temp_dir):
        if item.startswith("tmp_binary_mask_") and item.endswith(".png"):
            try:
                os.remove(os.path.join(temp_dir, item))
            except OSError:
                pass  # Locked or already removed; cleanup is optional.
    demo.launch(share=True)