import os
import subprocess
import sys

def run(cmd, cwd=None):
    """Echo *cmd*, then execute it through the shell.

    Raises subprocess.CalledProcessError if the command exits non-zero.
    NOTE(review): shell=True with a string command is injection-prone in
    general; acceptable here because every call site passes a literal.
    """
    banner = f"▶ {cmd}"
    print(banner)
    subprocess.check_call(cmd, cwd=cwd, shell=True)

def setup_deps():
    """Install heavy dependencies on first run, then re-exec this script.

    Intended for a Hugging Face Space cold start: pip-installs torch and
    SAM2, then replaces the current process with a fresh interpreter so the
    newly installed packages become importable. The HF_SPACE_BOOTSTRAPPED
    environment flag stops the install/restart cycle from repeating.
    """
    # Use a flag to prevent infinite restarts
    if os.environ.get("HF_SPACE_BOOTSTRAPPED") == "1":
        return

    # Try importing something to check if it's already set up
    try:
        import torch
        import sam2
        print("🔧 Dependencies already installed.")
        return  # all good, don't reinstall
    except ImportError:
        pass

    print("🔧 Installing dependencies...")
    # CPU-only wheel index keeps the torch download small (no CUDA libs).
    run("pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
    # SAM2 is installed editable from a checked-out sibling directory.
    run("pip install -e .", cwd="segment-anything-2")
    run("pip install --no-deps -r requirements_manual.txt")

    # Relaunch the script with an env flag to avoid looping
    print("♻️ Restarting app to apply changes...")
    # os.environ mutations are inherited across os.execv, so the flag
    # survives into the restarted process.
    os.environ["HF_SPACE_BOOTSTRAPPED"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)

setup_deps()

import gradio as gr
import numpy as np
from PIL import Image
import sam_utils
import matplotlib.pyplot as plt
from io import BytesIO

from sam2.sam2_image_predictor import SAM2ImagePredictor


# Dummy placeholders for SAM2 functions (replace with real logic)
def segment_reference(image, click):
    """Segment the reference image with SAM2 using clicked prompt points.

    Args:
        image: PIL image to segment (passed straight to the predictor).
        click: list of [x, y] pixel coordinates; every point is treated as
            a positive (foreground) prompt.

    Returns:
        The mask array produced by SAM2ImagePredictor.predict with
        multimask_output=False.
    """
    print(f"Segmenting reference at point: {click}")
    points = np.array(click)
    # One positive label (1) per prompt point; replaces the original
    # hand-rolled list comprehension. The unused `width, height = image.size`
    # line from the original has been dropped.
    labels = np.ones(len(points), dtype=int)
    sam2_img.set_image(image)

    masks, _, _ = sam2_img.predict(
        point_coords=points,
        point_labels=labels,
        multimask_output=False,
    )

    return masks

def segment_target(target_images, ref_image, ref_mask):
    """Propagate the reference mask onto each target image via SAM2 video propagation."""
    frames = [np.array(img) for img in target_images]
    state = sam_utils.load_masks(sam2_vid, frames, np.array(ref_image), ref_mask)
    propagated = sam_utils.propagate_masks(sam2_vid, state)
    # The first entry corresponds to the reference frame; keep only targets.
    return [entry['segmentation'] for entry in propagated[1:]]

def on_reference_upload(img):
    """Discard accumulated click prompts whenever a new reference image is uploaded."""
    global click_coords
    click_coords = list()  # stale prompt points belong to the previous image
    return "Click Info: Cleared (new image uploaded)"

def visualize_segmentation(image, masks, target_images, target_masks):
    """Render reference and target segmentations into one PIL image.

    Top row: the reference image (grayscale) with the expert masks overlaid.
    Bottom row: each target image with its propagated masks. The figure is
    rasterized to PNG in memory and returned as a PIL image.
    """
    n = len(target_images)
    fig, axes = plt.subplots(2, n, figsize=(6 * n, 12))
    if n == 1:
        # A single column yields a 1-D axes array; force 2-D indexing.
        axes = np.expand_dims(axes, axis=1)

    ref_ax = axes[0][0]
    ref_ax.imshow(image.convert("L"), cmap='gray')
    for obj_id, mask in enumerate(masks):
        sam_utils.show_mask(mask, ref_ax, obj_id=obj_id, alpha=0.75)
    ref_ax.axis('off')
    ref_ax.set_title("Reference Image with Expert Segmentation")

    # Remaining top-row axes carry no content; just hide them.
    for col in range(1, n):
        axes[0][col].axis('off')

    for col in range(n):
        tgt_ax = axes[1][col]
        tgt_ax.imshow(target_images[col].convert("L"), cmap='gray')
        for obj_id, mask in enumerate(target_masks[col]):
            sam_utils.show_mask(mask, tgt_ax, obj_id=obj_id, alpha=0.75)
        tgt_ax.axis('off')
        tgt_ax.set_title("Target Image with Inferred Segmentation")

    # Rasterize into an in-memory buffer and copy out before closing,
    # so the returned image does not depend on the (closed) buffer.
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    vis = Image.open(buf).copy()
    plt.close(fig)
    buf.close()
    return vis

# Store click coords globally (can be improved with gr.State): accumulated
# [x, y] prompt points on the reference image, appended by record_click and
# cleared by on_reference_upload and at the end of a successful generate().
click_coords = []

def record_click(img, evt: gr.SelectData):
    """Append the clicked pixel coordinate to the global prompt list."""
    global click_coords
    x, y = evt.index[0], evt.index[1]
    click_coords.append([x, y])
    return f"Clicked at: {click_coords}"

def generate(reference_image, target_images):
    """Gradio callback: segment the reference at the clicked points, then
    propagate the mask(s) onto every uploaded target image.

    Args:
        reference_image: PIL image from the Reference Image component.
        target_images: list of uploaded file objects (each with a .name path),
            or None when nothing has been uploaded.

    Returns:
        (visualization PIL image or None, status string for the Click Info box).
    """
    global click_coords
    if not click_coords:
        return None, "Click on the reference image first!"
    # Original crashed with a TypeError when no target files were uploaded;
    # surface a friendly message instead.
    if not target_images:
        return None, "Upload at least one target image first!"

    # Load each upload, normalize to the 1024x1024 resolution the models
    # expect, and close the underlying file handle (the original leaked it).
    resized_targets = []
    for f in target_images:
        with Image.open(f.name) as img:
            resized_targets.append(img.convert("RGB").resize((1024, 1024)))

    ref_mask = segment_reference(reference_image, click_coords)
    tgt_masks = segment_target(resized_targets, reference_image, ref_mask)
    vis = visualize_segmentation(reference_image, ref_mask, resized_targets, tgt_masks)
    click_coords = []  # reset prompts so the next run starts clean
    return vis, "Done!"

# UI layout and event wiring for the demo.
with gr.Blocks() as demo:
    gr.Markdown("### SST Demo: Label-Efficient Trait Segmentation")
    
    with gr.Row():
        reference_img = gr.Image(type="pil", label="Reference Image")
        target_img = gr.File(file_types=["image"], file_count="multiple", label="Target Images")
    
    click_info = gr.Textbox(label="Click Info")
    generate_btn = gr.Button("Generate")
    output_mask = gr.Image(type="pil", label="Generated Mask")

    # Clicks on the reference image accumulate prompt points...
    reference_img.select(fn=record_click, inputs=[reference_img], outputs=[click_info])
    # ...and uploading a new reference image discards them.
    reference_img.change(fn=on_reference_upload, inputs=[reference_img], outputs=[click_info])
    generate_btn.click(fn=generate, inputs=[reference_img, target_img], outputs=[output_mask, click_info])

# Module-scope assignment already creates globals; the original's bare
# `global sam2_img` / `global sam2_vid` statements at module level were
# no-ops and have been removed.
# Single-image predictor used by segment_reference().
sam2_img = sam_utils.load_SAM2(ckpt_path="checkpoints/sam2_hiera_small.pt", model_cfg_path="checkpoints/sam2_hiera_s.yaml")
sam2_img = SAM2ImagePredictor(sam2_img)
# Video predictor used by segment_target() for mask propagation.
sam2_vid = sam_utils.build_sam2_predictor(checkpoint="checkpoints/sam2_hiera_small.pt", model_cfg="checkpoints/sam2_hiera_s.yaml")
demo.launch()