File size: 10,698 Bytes
9f37f40
 
bf6dd51
9f37f40
 
 
 
 
2c4b8a3
212f0f5
2c4b8a3
6535fce
2c4b8a3
9f37f40
 
 
 
 
 
 
 
 
 
 
386f67c
9f37f40
 
 
 
 
 
 
 
 
 
 
1e713b8
 
9f37f40
 
 
 
ce4fb35
9f37f40
 
29e81be
b961115
 
e7e717e
b961115
ce4fb35
 
 
2c4b8a3
58f6d0d
29e81be
2c4b8a3
29e81be
 
 
e886b5d
e7e717e
b961115
29e81be
2c4b8a3
 
b961115
 
e7e717e
b961115
 
 
 
 
 
 
2c4b8a3
 
 
 
 
 
b961115
 
 
65651fb
b961115
2c4b8a3
 
b961115
 
 
 
 
 
 
 
 
 
 
74d1fdf
2c4b8a3
b961115
2c4b8a3
 
b961115
 
 
74d1fdf
b961115
 
74d1fdf
b961115
 
74d1fdf
b961115
 
 
 
 
 
 
 
 
 
 
 
 
 
9f37f40
b961115
9f37f40
b961115
9f37f40
 
 
 
 
 
 
b961115
 
9f37f40
5643464
 
1aa4620
b961115
58e191c
9f37f40
 
 
2c4b8a3
 
 
9f37f40
e7e717e
caabd4b
 
 
 
 
 
 
 
 
 
 
 
58b9a20
 
caabd4b
 
e7e717e
caabd4b
9f37f40
caabd4b
 
 
 
 
 
34fcfd8
caabd4b
 
 
 
 
9f37f40
65651fb
ca021c4
 
 
65651fb
ca021c4
 
 
8fc1551
9879b4f
ce4fb35
 
29e81be
ca021c4
 
 
65651fb
ca021c4
 
ce4fb35
 
 
 
 
ca021c4
677bda2
65651fb
ca021c4
 
9f37f40
02e5a68
ce4fb35
02e5a68
 
8f8b4f5
 
 
 
 
 
 
 
 
 
 
 
 
2c4b4cb
ba1d3f9
 
cd4e87d
 
 
34fcfd8
b59122a
ba1d3f9
 
8271902
9eff9e4
ba1d3f9
7f7a2e1
ba1d3f9
65651fb
ba1d3f9
 
65651fb
 
677bda2
65651fb
677bda2
 
65651fb
7f7a2e1
677bda2
ba1d3f9
 
 
2c4b8a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import random
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple
import cv2
import torch
import requests
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
import gradio as gr
import spaces
import json

@dataclass
class BoundingBox:
    """Axis-aligned box in pixel coordinates, stored as its corner values."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Corners as a fresh ``[xmin, ymin, xmax, ymax]`` list."""
        corners = (self.xmin, self.ymin, self.xmax, self.ymax)
        return list(corners)

@dataclass
class DetectionResult:
    """One detection: confidence score, label, bounding box and an optional mask."""

    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None  # filled in later by segmentation

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        """Build a result from a dict with 'score', 'label' and a 'box' sub-dict.

        The 'box' sub-dict must provide 'xmin', 'ymin', 'xmax' and 'ymax'.
        The returned result has no mask.
        """
        score = detection_dict['score']
        label = detection_dict['label']
        raw_box = detection_dict['box']
        bbox = BoundingBox(
            xmin=raw_box['xmin'],
            ymin=raw_box['ymin'],
            xmax=raw_box['xmax'],
            ymax=raw_box['ymax'],
        )
        return cls(score=score, label=label, box=bbox)

def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult], include_bboxes: bool = True) -> np.ndarray:
    """Draw each detection's box and a ``label: score`` caption onto *image*.

    Args:
        image: RGB image, as a PIL image or a NumPy array.
        detection_results: Detections whose boxes, labels and scores are drawn.
        include_bboxes: When False, nothing is drawn.

    Returns:
        The (possibly annotated) image as an RGB NumPy array.
    """
    image_cv2 = np.array(image) if isinstance(image, Image.Image) else image
    # OpenCV drawing works in BGR order; convert in, draw, and convert back out.
    image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_RGB2BGR)

    if include_bboxes:
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 2
        for detection in detection_results:
            # cv2 drawing functions require integer pixel coordinates.
            xmin, ymin, xmax, ymax = (int(v) for v in detection.box.xyxy)
            color = np.random.randint(0, 256, size=3).tolist()
            caption = f'{detection.label}: {detection.score:.2f}'
            cv2.rectangle(image_cv2, (xmin, ymin), (xmax, ymax), color, thickness)
            # Put the caption 10px above the box, but never so high that it is drawn
            # outside the image (boxes touching the top edge would lose their label).
            (_, text_height), _ = cv2.getTextSize(caption, font, font_scale, thickness)
            text_y = max(ymin - 10, text_height)
            cv2.putText(image_cv2, caption, (xmin, text_y), font, font_scale, color, thickness)

    return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)

def plot_detections(image: Union[Image.Image, np.ndarray], detections: List[DetectionResult], include_bboxes: bool = True) -> np.ndarray:
    """Return *image* with *detections* drawn on it (delegates to ``annotate``)."""
    return annotate(image, detections, include_bboxes)

def load_image(image: Union[str, Image.Image]) -> Image.Image:
    """Load *image* and return it as an RGB PIL image.

    Args:
        image: An ``http://``/``https://`` URL, a local file path, or an
            already-loaded PIL image.

    Returns:
        The image converted to RGB.

    Raises:
        requests.HTTPError: if downloading a URL returns an error status.
        requests.Timeout: if the download does not respond in time.
    """
    from io import BytesIO

    if isinstance(image, str):
        # Match real URL schemes only, so a local file such as "http_photo.jpg"
        # is still opened from disk.
        if image.startswith(("http://", "https://")):
            # A timeout keeps a dead host from hanging the request handler forever.
            response = requests.get(image, timeout=30)
            response.raise_for_status()
            # Use the decoded body: ``response.raw`` bypasses Content-Encoding
            # (e.g. gzip) decoding, which PIL cannot read.
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.open(image)
    return image.convert("RGB")

def get_boxes(detection_results: List[DetectionResult]) -> List[List[List[float]]]:
    """Collect every detection's ``[xmin, ymin, xmax, ymax]`` box.

    The result is wrapped in one extra outer list: a single entry holding the
    boxes for the one image being processed. This is the nesting passed to the
    processor as ``input_boxes``.
    """
    return [[detection.box.xyxy for detection in detection_results]]

def mask_to_polygon(mask: np.ndarray) -> np.ndarray:
    """Return the largest external contour of *mask*.

    Returns an empty array when the mask has no contours.
    """
    contours, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not len(contours):
        return np.array([])
    # Keep only the contour enclosing the largest area.
    return max(contours, key=cv2.contourArea)

def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    """Collapse per-box candidate masks into one binary uint8 mask per box.

    Args:
        masks: Rank-4 boolean tensor. The code averages over axis 1 after moving
            it last. Presumably this is (num_boxes, num_candidates, H, W) from
            SAM's ``post_process_masks``; TODO confirm.
        polygon_refinement: If True, replace each mask by the filled largest
            external contour of that mask.

    Returns:
        A list with one (H, W) uint8 array of 0/1 values per box.
    """
    # Move the candidate axis last and average over it.
    averaged = masks.cpu().float().permute(0, 2, 3, 1).mean(dim=-1)
    # Threshold BEFORE casting to uint8. Casting first truncates fractional means
    # (e.g. 1/3 -> 0), which would keep only pixels present in every candidate
    # instead of those present in any candidate.
    binary = (averaged > 0).numpy().astype(np.uint8)
    if polygon_refinement:
        for idx, mask in enumerate(binary):
            polygon = mask_to_polygon(mask)
            if polygon.size == 0:
                # No contour means an empty mask; there is nothing to fill.
                continue
            binary[idx] = cv2.fillPoly(np.zeros(mask.shape, dtype=np.uint8), [polygon], 1)
    return list(binary)

@spaces.GPU
def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[DetectionResult]:
    """Run zero-shot object detection for *labels* on *image*.

    Args:
        image: Image to search.
        labels: Free-text object names to look for.
        threshold: Minimum score for a detection to be returned.
        detector_id: Hugging Face model id; defaults to Grounding DINO base.

    Returns:
        One ``DetectionResult`` (without mask) per detection from the pipeline.
    """
    detector_id = detector_id or "IDEA-Research/grounding-dino-base"
    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device="cuda")
    # Grounding DINO expects text queries that are lower-cased and end with a period.
    queries = []
    for label in labels:
        query = label.strip().lower()
        queries.append(query if query.endswith(".") else query + ".")
    results = object_detector(image, candidate_labels=queries, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]

@spaces.GPU
def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
    """Predict a mask for every detection box and attach it in place.

    Args:
        image: Image the detections were made on.
        detection_results: Detections to segment. Each gets its ``mask`` set.
        polygon_refinement: Forwarded to ``refine_masks``.
        segmenter_id: Hugging Face model id; defaults to "martintmv/InsectSAM".

    Returns:
        The same ``detection_results`` list, with masks filled in.
    """
    # With no boxes there is nothing to prompt the segmenter with, so skip
    # model loading and inference entirely.
    if not detection_results:
        return detection_results

    segmenter_id = segmenter_id or "martintmv/InsectSAM"
    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to("cuda")
    processor = AutoProcessor.from_pretrained(segmenter_id)
    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to("cuda")
    # Inference only: disable autograd bookkeeping to save GPU memory and time.
    with torch.no_grad():
        outputs = segmentator(**inputs)
    masks = processor.post_process_masks(masks=outputs.pred_masks, original_sizes=inputs.original_sizes, reshaped_input_sizes=inputs.reshaped_input_sizes)[0]
    masks = refine_masks(masks, polygon_refinement)
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask
    return detection_results

def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False, detector_id: Optional[str] = None, segmenter_id: Optional[str] = None) -> Tuple[np.ndarray, List[DetectionResult]]:
    """Load an image, detect *labels* in it, then segment each detection.

    Returns:
        A pair ``(rgb_array, detections)``: the loaded image as a NumPy array,
        and the detections with masks attached.
    """
    pil_image = load_image(image)
    found = detect(pil_image, labels, threshold, detector_id)
    found = segment(pil_image, found, polygon_refinement, segmenter_id)
    return np.array(pil_image), found

def mask_to_min_max(mask: np.ndarray) -> Tuple[int, int, int, int]:
    """Return the tight bounding box of the nonzero pixels of a 2-D mask.

    Returns:
        ``(xmin, ymin, xmax, ymax)`` as plain Python ints. The max values are
        inclusive: they are indices of the last nonzero column and row.

    Raises:
        ValueError: if the mask has no nonzero pixels.
    """
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        raise ValueError("mask_to_min_max: mask has no nonzero pixels")
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

def extract_and_paste_insect(original_image: np.ndarray, detection: "DetectionResult", background: np.ndarray) -> None:
    """Copy the pixels covered by ``detection.mask`` from *original_image* into *background*.

    Pixels are pasted at the same coordinates, and *background* is modified in place.

    Args:
        original_image: Source image of shape (H, W, C).
        detection: Detection whose ``mask`` is an (H, W) array; nonzero means "insect".
        background: Destination image with the same (H, W, C) shape as *original_image*.
    """
    # Index with the full-size boolean mask rather than cropping to the mask's
    # bounding box. This avoids an off-by-one: the box's max coordinates are
    # inclusive while slice ends are exclusive, so cropping drops the last row
    # and column. It also makes an all-zero mask a harmless no-op instead of
    # an error.
    selected = np.asarray(detection.mask).astype(bool)
    background[selected] = original_image[selected]

def create_yellow_background_with_insects(image: np.ndarray, detections: List[DetectionResult]) -> np.ndarray:
    """Paste every masked insect from *image* onto a plain yellow canvas.

    Args:
        image: RGB image array of shape (H, W, 3).
        detections: Detections. Those with ``mask`` set to None are skipped.

    Returns:
        A new (H, W, 3) uint8 RGB array: yellow everywhere except the insect pixels.
    """
    # Build the canvas directly in RGB. *image* is RGB, so making a BGR canvas
    # and converting it at the end would also swap the red and blue channels
    # of every pasted insect.
    yellow_background = np.full((image.shape[0], image.shape[1], 3), (255, 255, 0), dtype=np.uint8)
    for detection in detections:
        if detection.mask is not None:
            extract_and_paste_insect(image, detection, yellow_background)
    return yellow_background

def run_length_encoding(mask):
    """Run-length encode *mask* in row-major (C) order.

    The first count is always the length of the initial run of zeros, which is
    0 when the mask starts with a nonzero value. Counts then alternate between
    runs of equal values, so a binary mask can be decoded unambiguously.
    An empty mask encodes to ``[]``.

    Returns:
        A list of plain Python ints (JSON-serialisable).
    """
    pixels = np.asarray(mask).ravel()
    if pixels.size == 0:
        return []
    # Indices where the value changes mark run boundaries. Using numpy avoids a
    # per-pixel Python loop over the whole image.
    change_points = np.flatnonzero(pixels[1:] != pixels[:-1]) + 1
    boundaries = np.concatenate(([0], change_points, [pixels.size]))
    runs = np.diff(boundaries).tolist()
    if pixels[0] != 0:
        # Keep the "starts with a zero-run" convention so decoders know which
        # value the first count belongs to.
        runs.insert(0, 0)
    return runs

def detections_to_json(detections):
    """Convert detections into JSON-serialisable dicts.

    Each dict has "score", "label", "box" (with "xmin", "ymin", "xmax" and
    "ymax") and "mask". "mask" holds ``run_length_encoding`` counts, or None
    when the detection has no mask.
    """
    def _plain(value):
        # numpy/torch scalars are not JSON-serialisable; unwrap them without changing the value.
        return value.item() if hasattr(value, "item") else value

    detections_list = []
    for detection in detections:
        box = detection.box
        detections_list.append({
            "score": _plain(detection.score),
            "label": detection.label,
            "box": {
                "xmin": _plain(box.xmin),
                "ymin": _plain(box.ymin),
                "xmax": _plain(box.xmax),
                "ymax": _plain(box.ymax),
            },
            "mask": run_length_encoding(detection.mask) if detection.mask is not None else None,
        })
    return detections_list

def crop_bounding_boxes_with_yellow_background(image: np.ndarray, yellow_background: np.ndarray, detections: List["DetectionResult"]) -> List[np.ndarray]:
    """Cut each detection's bounding box out of the yellow composite.

    Args:
        image: Unused; kept so existing callers keep working.
        yellow_background: (H, W, ...) image to crop from.
        detections: Detections whose ``box.xyxy`` gives the crop region.

    Returns:
        One crop per detection whose box overlaps the image, in detection
        order. Crops are views into *yellow_background*. Boxes that end up
        empty after clipping to the image are skipped.
    """
    height, width = yellow_background.shape[:2]
    crops = []
    for detection in detections:
        xmin, ymin, xmax, ymax = (int(v) for v in detection.box.xyxy)
        # Clip to the image. A negative start would otherwise wrap around in
        # Python slicing and return a wrong or empty crop for boxes that
        # touch the border.
        xmin, xmax = max(0, xmin), min(width, xmax)
        ymin, ymax = max(0, ymin), min(height, ymax)
        if xmax <= xmin or ymax <= ymin:
            continue  # zero-area crop: nothing to display
        crops.append(yellow_background[ymin:ymax, xmin:xmax])
    return crops

def process_image(image, include_json, include_bboxes):
    """Run the insect pipeline on *image* and assemble the UI outputs.

    Returns a tuple that starts with the annotated composite. Then come the
    per-box crops (only if *include_bboxes*), then the compact JSON string
    (only if *include_json*). When neither option is set, a trailing None is
    added. When JSON is requested, the detections are also written to
    "insect_detections.json".
    """
    original_image, detections = grounded_segmentation(image, ["insect"], threshold=0.3, polygon_refinement=True)
    composited = create_yellow_background_with_insects(np.array(original_image), detections)
    outputs = [plot_detections(composited, detections, include_bboxes)]

    if include_bboxes:
        outputs += crop_bounding_boxes_with_yellow_background(np.array(original_image), composited, detections)

    if include_json:
        payload = detections_to_json(detections)
        with open("insect_detections.json", 'w') as handle:
            json.dump(payload, handle, indent=4)
        outputs.append(json.dumps(payload, separators=(',', ':')))
    elif not include_bboxes:
        outputs.append(None)

    return tuple(outputs)

# Example rows for gr.Examples. Each row holds a single value, an image path.
# NOTE(review): presumably a file shipped next to this script; confirm it exists.
examples = [
    ["flower-night.jpg"]
]

# Custom stylesheet passed to gr.Blocks(css=...). The selectors match elements
# carrying the CSS class `checkbox-group` and `.gr-checkbox` elements inside them.
css = """
.checkbox-group {
    display: flex;
    justify-content: center;
    gap: 20px;
    margin-bottom: 20px;
}
.checkbox-group .gr-checkbox {
    width: auto;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("InsectSAM 🐞 Detect and Segment Insects in Datasets")
    with gr.Row():
        image_input = gr.Image(type="pil")
        with gr.Column():
            # `css` styles the `.checkbox-group` *class* as a container of checkboxes,
            # so attach it with elem_classes. An elem_id would not match a class
            # selector, and HTML ids must be unique anyway.
            with gr.Row(elem_classes="checkbox-group"):
                include_json = gr.Checkbox(label="Include JSON", value=False)
                include_bboxes = gr.Checkbox(label="Include Bounding Boxes", value=False)
            # Each example row supplies only an image, so bind it to the image input alone.
            gr.Examples(examples=examples, inputs=[image_input])
        submit_button = gr.Button("Submit")

    annotated_output = gr.Image(type="numpy")
    json_output = gr.Textbox(label="JSON")
    crops_output = gr.Gallery(label="Cropped Bounding Boxes")

    # A plain (sync) handler: Gradio runs sync functions in a worker thread.
    # An `async def` calling the blocking pipeline would stall the event loop
    # for every connected user.
    def update_outputs(image, include_json, include_bboxes):
        """Split process_image's variable-length tuple into (image, json, crops)."""
        results = process_image(image, include_json, include_bboxes)
        if include_bboxes and include_json:
            annotated_img, *crops, json_txt = results
            return (annotated_img, json_txt, crops)
        elif include_bboxes:
            annotated_img, *crops = results
            return (annotated_img, None, crops)
        elif include_json:
            annotated_img, json_txt = results
            return (annotated_img, json_txt, [])
        else:
            annotated_img = results[0]
            return (annotated_img, None, [])

    submit_button.click(update_outputs, [image_input, include_json, include_bboxes], [annotated_output, json_output, crops_output])

demo.launch()