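"""Gradio demo for promptable segmentation with SAM (MobileSAM checkpoint) and GroundingDINO.

The UI supports point, box, and text prompts; generated result and mask
visualizations are cached in module-level lists and browsed by index.
"""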
import os  # used only by the optional editable-install command below
# Optional one-time setup: install GroundingDINO and segment_anything in editable mode.
# os.system('cd GroundingDINO && pip install -e . && cd .. && cd segment_anything && pip install -e . && cd ..')
import cv2
import gradio as gr
from PIL import Image
import numpy as np
from sam_extension.utils import add_points_tag, add_boxes_tag, mask2greyimg
from sam_extension.pipeline import SAMEncoderPipeline, SAMDecoderPipeline, GroundingDinoPipeline
# Prompt and result state shared across the Gradio callbacks below.
point_coords = []   # clicked point prompts, [[x, y], ...]
point_labels = []   # per-point labels: 1 = positive, 0 = negative
boxes = []          # completed box prompts, [[x1, y1, x2, y2], ...]
boxes_point = []    # corners of the box currently being drawn (flat [x1, y1, x2, y2])
texts = []          # placeholder for text prompts (the UI textbox is used instead)
sam_encoder_pipeline = None   # set by load_sam()
sam_decoder_pipeline = None   # set by load_sam()
result_list = []              # cached result visualizations from generate_mask()
result_index_list = []
mask_result_list = []         # cached per-mask visualizations from generate_mask()
mask_result_index_list = []
def resize(image, des_max=512):
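    """Resize an image so its longer side equals des_max, preserving aspect ratio (standalone helper)."""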
    h, w = image.shape[:2]
    if h >= w:
        new_h = des_max
        new_w = int(des_max * w / h)
    else:
        new_w = des_max
        new_h = int(des_max * h / w)
    return cv2.resize(image, (new_w, new_h))
def show_prompt(img, prompt_mode, pos_point, evt: gr.SelectData):  # SelectData is a subclass of EventData
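    """Handle a click on the input image: record a point or box-corner prompt and redraw the overlay."""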
    global point_coords, point_labels, boxes_point, boxes
    if prompt_mode == 'point':
        point_coords.append([evt.index[0], evt.index[1]])
        point_labels.append(1 if pos_point else 0)
        result_img = add_points_tag(img, np.array(point_labels), np.array(point_coords))
    elif prompt_mode == 'box':
        boxes_point.append(evt.index[0])
        boxes_point.append(evt.index[1])
        if len(boxes_point) == 4:
            boxes.append(boxes_point)
            boxes_point = []
        result_img = add_boxes_tag(img, np.array(boxes))
    else:
        result_img = img
    return result_img, point_coords, point_labels, boxes_point, boxes

def reset_points(img):
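    """Clear all point prompts and show the unannotated image again."""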
    global point_coords, point_labels
    point_coords = []
    point_labels = []
    return img, point_coords, point_labels


def reset_boxes(img):
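    """Clear all box prompts, including any half-finished box."""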
    global boxes_point, boxes
    boxes_point = []
    boxes = []
    return img, boxes_point, boxes

def load_sam(sam_ckpt_path, sam_version):
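    """Instantiate the SAM encoder/decoder pipelines from the selected checkpoint (on CPU)."""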
    global sam_encoder_pipeline, sam_decoder_pipeline
    sam_encoder_pipeline = SAMEncoderPipeline.from_pretrained(ckpt_path=sam_ckpt_path, sam_version=sam_version, device='cpu')
    sam_decoder_pipeline = SAMDecoderPipeline.from_pretrained(ckpt_path=sam_ckpt_path, sam_version=sam_version, device='cpu')
    return 'sam loaded!'


def generate_mask(img, prompt_mode, text_prompt):
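    """Run SAM on the current prompts and cache the visualizations.

    Point and box prompts come from the module-level lists; text prompts are
    comma-separated phrases grounded with GroundingDINO. Results are stored in
    result_list / mask_result_list for browsing by index.
    """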
    global result_list, mask_result_list, result_index_list, mask_result_index_list
    image = Image.fromarray(img)
    img_size = sam_decoder_pipeline.img_size
    des_img = image.resize((img_size, img_size))
    sam_encoder_output = sam_encoder_pipeline(des_img)
    if prompt_mode == 'point':
        point_coords_ = np.array(point_coords)
        point_labels_ = np.array(point_labels)
        boxes_ = None
        texts_ = None
        grounding_dino_pipeline = None
    elif prompt_mode == 'box':
        point_coords_ = None
        point_labels_ = None
        boxes_ = np.array(boxes)
        texts_ = None
        grounding_dino_pipeline = None
    else:
        point_coords_ = None
        point_labels_ = None
        boxes_ = None
        texts_ = [t.strip() for t in text_prompt.split(',')]  # comma-separated phrases, whitespace stripped
        grounding_dino_pipeline = GroundingDinoPipeline.from_pretrained(
            'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',
            'weights/groundingdino/groundingdino_swint_ogc.pth',
            device='cpu')
    result_list, mask_result_list, masks_list = sam_decoder_pipeline.visualize_results(
        image,
        des_img,
        sam_encoder_output,
        point_coords=point_coords_,
        point_labels=point_labels_,
        boxes=boxes_,
        texts=texts_,
        grounding_dino_pipeline=grounding_dino_pipeline,
        multimask_output=True,
        visualize_promts=True,
        pil=False)
    # result_index_list = [f'result_{i}' for i in range(len(result_list))]
    # mask_result_index_list = [f'mask_result_{i}' for i in range(len(mask_result_list))]
    return 'mask generated!', f'result_num : {len(result_list)}', f'mask_result_num : {len(masks_list)}'
    # mask_grey_result_list = mask2greyimg(masks_list, False)


def show_result(result_index):
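    """Return the cached result visualization at the given index."""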
    return result_list[int(result_index)]


def show_mask_result(mask_result_index):
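    """Return the cached per-mask visualization at the given index."""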
    return mask_result_list[int(mask_result_index)]


# Gradio UI: image panels, prompt controls, SAM loading, and result browsing.
with gr.Blocks() as demo:
    with gr.Row():
        img = gr.Image(None, width=400, height=400, label='input_image', type='numpy')
        result_img = gr.Image(None, width=400, height=400, label='output_image', type='numpy')
    with gr.Row():
        pos_point = gr.Checkbox(value=True, label='pos_point')
        prompt_mode = gr.Dropdown(choices=['point', 'box', 'text'], value='point', label='prompt_mode')
    with gr.Row():
        point_coords_text = gr.Textbox(value=str(point_coords), interactive=True, label='point_coords')
        point_labels_text = gr.Textbox(value=str(point_labels), interactive=True, label='point_labels')
        reset_points_bu = gr.Button(value='reset_points')
        reset_points_bu.click(fn=reset_points, inputs=[img], outputs=[result_img, point_coords_text, point_labels_text])
    with gr.Row():
        boxes_point_text = gr.Textbox(value=str(boxes_point), interactive=True, label='boxes_point')
        boxes_text = gr.Textbox(value=str(boxes), interactive=True, label='boxes')
        reset_boxes_bu = gr.Button(value='reset_boxes')
        reset_boxes_bu.click(fn=reset_boxes, inputs=[img], outputs=[result_img, boxes_point_text, boxes_text])
    with gr.Row():
        text_prompt = gr.Textbox(value='', interactive=True, label='text_prompt')
    with gr.Row():
        sam_ckpt_path = gr.Dropdown(choices=['weights/sam/mobile_sam.pt'],
                                    value='weights/sam/mobile_sam.pt',
                                    label='SAM ckpt_path')
        sam_version = gr.Dropdown(choices=['mobile_sam'],
                                  value='mobile_sam',
                                  label='SAM version')
        load_sam_bu = gr.Button(value='load SAM')
        sam_load_text = gr.Textbox(value='', interactive=True, label='sam_load')
        load_sam_bu.click(fn=load_sam, inputs=[sam_ckpt_path, sam_version], outputs=sam_load_text)
    with gr.Row():
        result_num_text = gr.Textbox(value='', interactive=True, label='result_num')
        result_index = gr.Number(value=0, label='result_index')
        show_result_bu = gr.Button(value='show_result')
        show_result_bu.click(fn=show_result, inputs=[result_index], outputs=[result_img])
    with gr.Row():
        mask_result_num_text = gr.Textbox(value='', interactive=True, label='mask_result_num')
        mask_result_index = gr.Number(value=0, label='mask_result_index')
        show_mask_result_bu = gr.Button(value='show_mask_result')
        show_mask_result_bu.click(fn=show_mask_result, inputs=[mask_result_index], outputs=[result_img])
    with gr.Row():
        generate_masks_bu = gr.Button(value='SAM generate masks')
        sam_text = gr.Textbox(value='', interactive=True, label='SAM')
        generate_masks_bu.click(fn=generate_mask, inputs=[img, prompt_mode, text_prompt], outputs=[sam_text, result_num_text, mask_result_num_text])
    img.select(show_prompt, [img, prompt_mode, pos_point], [result_img, point_coords_text, point_labels_text, boxes_point_text, boxes_text])
if __name__ == '__main__':
    demo.launch()
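    # To expose the demo on the network, demo.launch(server_name='0.0.0.0', share=True) can be used instead.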