File size: 5,453 Bytes
b3460e8
71d1134
 
 
 
b3460e8
 
71d1134
 
 
 
 
 
 
b3460e8
71d1134
 
 
b3460e8
71d1134
 
 
 
 
 
 
 
 
 
b3460e8
71d1134
 
 
 
 
b3460e8
71d1134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdba3b8
71d1134
 
 
 
 
 
b3460e8
71d1134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdba3b8
71d1134
 
 
 
 
 
 
 
 
 
bdba3b8
71d1134
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import numpy as np
from PIL import Image
import tempfile
import os


# Function to create the binary mask from the ImageEditor's output
def create_binary_mask(im_dict):
    """Turn the brush strokes of an ImageEditor value into a black/white mask.

    Args:
        im_dict: The editor value, read here through its "background" and
            "layers" keys, or None when nothing has been loaded. The arrays
            are assumed to be NumPy arrays (the editor in this file is
            created with type="numpy").

    Returns:
        A ``(mask, filepath)`` tuple:

        * ``mask`` is a 2-D ``uint8`` array, 255 where the first layer's
          alpha is greater than 0 and 0 elsewhere. It is all zeros when
          there is no background, no drawing layer, or the PNG could not be
          written.
        * ``filepath`` is the path of a temporary PNG holding the mask, or
          None whenever the mask is blank for one of the reasons above.
    """
    if im_dict is None or im_dict["background"] is None:
        print("No background image found.")
        # Fixed-size blank placeholder so the preview still has something to show
        return np.zeros((768, 1024), dtype=np.uint8), None

    background_img = im_dict["background"]
    # Only height and width are needed; slicing the shape (instead of
    # unpacking three values) also accepts a 2-D grayscale background.
    h, w = background_img.shape[:2]
    print(f"Original image dimensions: H={h}, W={w}")

    layers = im_dict["layers"]
    if not layers or layers[0] is None:
        print("No drawing layer found. Generating blank mask.")
        # Nothing drawn yet: black mask of the background's size, nothing to download
        return np.zeros((h, w), dtype=np.uint8), None

    # Only the first layer is inspected; strokes on any further layers are ignored
    layer = layers[0]
    print(f"Drawing layer dimensions: H={layer.shape[0]}, W={layer.shape[1]}")

    if layer.shape[0] != h or layer.shape[1] != w:
        # Reported only; the mask below follows the layer's own size
        print(f"Warning: Layer size ({layer.shape[0]}x{layer.shape[1]}) doesn't match background ({h}x{w}). This shouldn't happen.")

    # Assumes the layer is RGBA so that channel index 3 is alpha
    alpha_channel = layer[:, :, 3]

    # White (255) wherever something was drawn (alpha > 0), black (0) elsewhere
    mask = np.where(alpha_channel > 0, 255, 0).astype(np.uint8)
    print(f"Generated binary mask dimensions: H={mask.shape[0]}, W={mask.shape[1]}")

    filepath = None
    try:
        # Reserve a unique path and close the handle right away so the PNG
        # writer can reopen the file by name; delete=False keeps it on disk
        # for the download button.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            filepath = tmpfile.name

        Image.fromarray(mask).save(filepath, format="PNG")
        print(f"Mask saved temporarily to: {filepath}")
    except Exception as e:
        print(f"Error saving mask to temporary file: {e}")
        # The reserved file is empty or partial at this point; remove it so a
        # failed save does not leave an orphan behind.
        if filepath is not None:
            try:
                os.remove(filepath)
            except OSError:
                pass  # already gone or not removable; nothing more to do
        filepath = None  # signal that there is nothing to download
        # Fall back to a blank mask when the file could not be written
        mask = np.zeros((h, w), dtype=np.uint8)

    return mask, filepath

# --- Gradio App Layout ---
with gr.Blocks() as demo:
    gr.Markdown("## Binary Mask Generator")
    # The instructions mention upload only, matching sources=["upload"] below
    gr.Markdown(
        "Upload an image. Use the brush tool (select it!) to draw the area "
        "you want to mask. Click 'Generate Mask' to see the result and download it."
    )

    with gr.Row():
        # --- Left Column: editor and trigger button ---
        with gr.Column(scale=1):
            image_editor = gr.ImageEditor(
                label="Draw on Image",
                # The callback indexes NumPy arrays (layer[:, :, 3]), so the
                # editor value must be delivered as numpy
                type="numpy",
                # Only file upload is enabled as an image source
                sources=["upload"],
                # Single fixed red brush so there is no colour to choose
                brush=gr.Brush(colors=["#FF0000"], color_mode="fixed"),
                interactive=True,
                # NOTE(review): confirm that, with this canvas_size, uploaded
                # images keep their native resolution in the installed Gradio
                # version; the mask is sized from the arrays the editor returns.
                canvas_size=(768, 1024)
            )
            generate_button = gr.Button("Generate Mask", variant="primary")

        # --- Right Column: result preview and download ---
        with gr.Column(scale=1):
            mask_preview = gr.Image(
                label="Binary Mask Preview",
                type="numpy",
                interactive=False,  # display only
            )
            # Receives the second return value of create_binary_mask
            # (a PNG file path, or None when there is nothing to download)
            download_button = gr.DownloadButton(
                label="Download Mask (PNG)",
                interactive=True,
            )

    # --- Event Handling ---
    # create_binary_mask returns (mask array, file path); the outputs list
    # maps them positionally to the preview image and the download button.
    generate_button.click(
        fn=create_binary_mask,
        inputs=[image_editor],
        outputs=[mask_preview, download_button]
    )

# --- Launch the App ---
if __name__ == "__main__":
    # Best-effort removal of mask PNGs left in the system temp directory by
    # earlier runs (they are created with delete=False and never removed).
    # NOTE(review): the "tmp*.png" pattern is not specific to this app and may
    # also match temporary PNGs that belong to other programs.
    temp_dir = tempfile.gettempdir()
    for item in os.listdir(temp_dir):
        if not (item.endswith(".png") and item.startswith("tmp")):
            continue
        stale_path = os.path.join(temp_dir, item)
        try:
            os.remove(stale_path)
        except OSError as e:
            # Locked, already deleted, not a regular file, etc. — cleanup is
            # optional, so report it and keep going.
            print(f"Could not remove stale temp file {stale_path}: {e}")

    # share=True asks Gradio to expose a public share link as well
    demo.launch(share=True)