Spaces:
Running
Running
File size: 7,315 Bytes
26c3b64 d0503eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import os
#os.system('cd GroundingDINO && pip install -e. && cd .. && cd segment_anything && pip install -e. && cd ..')
import cv2
import gradio as gr
from PIL import Image
import numpy as np
from sam_extension.utils import add_points_tag, add_boxes_tag, mask2greyimg
from sam_extension.pipeline import SAMEncoderPipeline, SAMDecoderPipeline, GroundingDinoPipeline
# ---------------------------------------------------------------------------
# Module-level UI state shared by the Gradio callbacks below.
# ---------------------------------------------------------------------------
# Prompts collected from clicks on the input image.
point_coords = list()   # clicked points, each [index[0], index[1]] of the select event
point_labels = list()   # per point: 1 when "pos_point" was checked, else 0
boxes = list()          # completed boxes, each a 4-element list of click coordinates
boxes_point = list()    # coordinates of a box still being clicked (fewer than 4)
texts = list()          # NOTE(review): not referenced by the visible code
# SAM pipelines; they stay None until load_sam() fills them in.
sam_encoder_pipeline = None
sam_decoder_pipeline = None
# Visualizations produced by the most recent generate_mask() call.
result_list = list()
mask_result_list = list()
# NOTE(review): the two index lists below are only referenced by commented-out code.
result_index_list = list()
mask_result_index_list = list()
def resize(image, des_max=512):
    """Scale *image* so that its longer side becomes ``des_max`` pixels.

    The aspect ratio is kept; the shorter side is truncated to an int and
    clamped to at least 1 pixel, because a very elongated image would
    otherwise get a zero-sized target that cv2.resize rejects.

    Args:
        image: array whose first two dimensions are (height, width).
        des_max: target length of the longer side, in pixels.

    Returns:
        The array returned by cv2.resize.

    Raises:
        ValueError: if the image has zero height or width.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError('cannot resize an empty image')
    if h >= w:
        new_h = des_max
        new_w = max(1, int(des_max * w / h))
    else:
        new_w = des_max
        new_h = max(1, int(des_max * h / w))
    # cv2.resize takes dsize as (width, height), not (height, width).
    return cv2.resize(image, (new_w, new_h))
def show_prompt(img, prompt_mode, pos_point, evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Record a click on the input image as a prompt and draw the prompts.

    'point' mode: the click becomes a point prompt labelled 1 when
    ``pos_point`` is checked and 0 otherwise; all points are drawn.
    'box' mode: every two clicks form one box. The two clicks may be any
    pair of opposite corners; the stored box is normalized to
    [min x, min y, max x, max y]. All completed boxes are drawn.
    Any other mode leaves the prompts untouched and echoes the image.

    Returns:
        (image with prompts drawn, point_coords, point_labels,
         boxes_point, boxes)
    """
    global point_coords, point_labels, boxes_point, boxes
    x, y = evt.index[0], evt.index[1]
    if prompt_mode == 'point':
        point_coords.append([x, y])
        point_labels.append(1 if pos_point else 0)
        result_img = add_points_tag(img, np.array(point_labels), np.array(point_coords))
    elif prompt_mode == 'box':
        boxes_point.extend([x, y])
        if len(boxes_point) == 4:
            x0, y0, x1, y1 = boxes_point
            # Accept the corners in either order. Box prompts are presumably
            # forwarded to SAM, which expects XYXY (top-left, bottom-right).
            boxes.append([min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)])
            boxes_point = []
        result_img = add_boxes_tag(img, np.array(boxes))
    else:
        result_img = img
    return result_img, point_coords, point_labels, boxes_point, boxes
def reset_points(img):
    """Forget every clicked point prompt.

    Returns the input image unchanged (so the output view loses its drawn
    point tags) together with the freshly emptied coordinate and label lists.
    """
    global point_coords, point_labels
    point_coords, point_labels = [], []
    return img, point_coords, point_labels
def reset_boxes(img):
    """Discard every box prompt, including a partially clicked one.

    Returns the input image unchanged plus the emptied partial-box buffer
    and completed-box list.
    """
    global boxes_point, boxes
    boxes_point, boxes = [], []
    return img, boxes_point, boxes
def load_sam(sam_ckpt_path, sam_version):
    """(Re)build the SAM encoder and decoder pipelines on CPU.

    Both pipelines are built from the same checkpoint path and version and
    stored in module-level globals that generate_mask() reads.

    Returns:
        A status message for the UI.
    """
    global sam_encoder_pipeline, sam_decoder_pipeline
    common_kwargs = dict(ckpt_path=sam_ckpt_path, sam_version=sam_version, device='cpu')
    sam_encoder_pipeline = SAMEncoderPipeline.from_pretrained(**common_kwargs)
    sam_decoder_pipeline = SAMDecoderPipeline.from_pretrained(**common_kwargs)
    return 'sam loaded!'
# GroundingDINO pipeline shared by all text-prompt requests; built on first use
# so the checkpoint is not re-read from disk on every click.
_grounding_dino_pipeline = None


def _get_grounding_dino_pipeline():
    """Return the shared CPU GroundingDINO pipeline, loading it only once."""
    global _grounding_dino_pipeline
    if _grounding_dino_pipeline is None:
        _grounding_dino_pipeline = GroundingDinoPipeline.from_pretrained(
            'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',
            'weights/groundingdino/groundingdino_swint_ogc.pth',
            device='cpu')
    return _grounding_dino_pipeline


def _parse_text_prompt(text_prompt):
    """Split a comma-separated prompt into trimmed, non-empty phrases."""
    return [phrase.strip() for phrase in (text_prompt or '').split(',') if phrase.strip()]


def generate_mask(img, prompt_mode, text_prompt):
    """Run SAM on the input image with the prompts collected for *prompt_mode*.

    'point' uses the clicked points and labels, 'box' uses the completed
    boxes, and any other mode uses the comma-separated ``text_prompt``
    through GroundingDINO. The image is resized to the decoder's square
    ``img_size`` before encoding. On success the visualizations are stored
    in ``result_list`` / ``mask_result_list`` for show_result() and
    show_mask_result().

    Returns:
        Three status strings for the SAM, result_num and mask_result_num
        textboxes. If SAM is not loaded, no image is given, or the selected
        mode has no prompts, the first string explains the problem and the
        other two are empty.
    """
    global result_list, mask_result_list
    if sam_encoder_pipeline is None or sam_decoder_pipeline is None:
        return 'please load SAM first!', '', ''
    if img is None:
        return 'please upload an image first!', '', ''

    point_coords_ = point_labels_ = boxes_ = texts_ = None
    grounding_dino_pipeline = None
    if prompt_mode == 'point':
        if not point_coords:
            return 'please click at least one point first!', '', ''
        point_coords_ = np.array(point_coords)
        point_labels_ = np.array(point_labels)
    elif prompt_mode == 'box':
        if not boxes:
            return 'please draw at least one box first!', '', ''
        boxes_ = np.array(boxes)
    else:
        texts_ = _parse_text_prompt(text_prompt)
        if not texts_:
            return 'please enter a text prompt first!', '', ''
        grounding_dino_pipeline = _get_grounding_dino_pipeline()

    # Validate the prompts before the (comparatively expensive) image encoding.
    image = Image.fromarray(img)
    img_size = sam_decoder_pipeline.img_size
    des_img = image.resize((img_size, img_size))
    sam_encoder_output = sam_encoder_pipeline(des_img)
    result_list, mask_result_list, _masks = sam_decoder_pipeline.visualize_results(
        image,
        des_img,
        sam_encoder_output,
        point_coords=point_coords_,
        point_labels=point_labels_,
        boxes=boxes_,
        texts=texts_,
        grounding_dino_pipeline=grounding_dino_pipeline,
        multimask_output=True,
        visualize_promts=True,
        pil=False)
    # Report the sizes of the lists the show_* callbacks actually index into.
    return ('mask generated!',
            f'result_num : {len(result_list)}',
            f'mask_result_num : {len(mask_result_list)}')
def _select_result(results, index, kind):
    """Return ``results[index]`` after checking *index* against the results.

    Raises gr.Error (shown in the Gradio UI) instead of a raw IndexError or
    TypeError. It also rejects negative indices, which would otherwise wrap
    around silently.
    """
    if not results:
        raise gr.Error(f'no {kind} available yet, generate masks first')
    if index is None:
        raise gr.Error(f'{kind} index is empty')
    position = int(index)  # gr.Number may deliver a float
    if not 0 <= position < len(results):
        raise gr.Error(f'{kind} index {position} out of range [0, {len(results) - 1}]')
    return results[position]


def show_result(result_index):
    """Return the generated result visualization at ``result_index``."""
    return _select_result(result_list, result_index, 'result')


def show_mask_result(mask_result_index):
    """Return the generated mask visualization at ``mask_result_index``."""
    return _select_result(mask_result_list, mask_result_index, 'mask result')
# Gradio UI: input/output images, prompt controls, SAM loading, mask
# generation and browsing of the generated results.
# Textboxes that only display state are non-interactive: none of them is used
# as a callback input, so user edits there would silently have no effect.
with gr.Blocks() as demo:
    with gr.Row():
        img = gr.Image(None, width=400, height=400, label='input_image', type='numpy')
        result_img = gr.Image(None, width=400, height=400, label='output_image', type='numpy')
    with gr.Row():
        # Checked: new points get label 1; unchecked: label 0.
        pos_point = gr.Checkbox(value=True, label='pos_point')
        prompt_mode = gr.Dropdown(choices=['point', 'box', 'text'], value='point', label='prompt_mode')
    with gr.Row():
        point_coords_text = gr.Textbox(value=str(point_coords), interactive=False, label='point_coords')
        point_labels_text = gr.Textbox(value=str(point_labels), interactive=False, label='point_labels')
        reset_points_bu = gr.Button(value='reset_points')
        reset_points_bu.click(fn=reset_points, inputs=[img], outputs=[result_img, point_coords_text, point_labels_text])
    with gr.Row():
        boxes_point_text = gr.Textbox(value=str(boxes_point), interactive=False, label='boxes_point')
        boxes_text = gr.Textbox(value=str(boxes), interactive=False, label='boxes')
        reset_boxes_bu = gr.Button(value='reset_boxes')
        reset_boxes_bu.click(fn=reset_boxes, inputs=[img], outputs=[result_img, boxes_point_text, boxes_text])
    with gr.Row():
        # Comma-separated phrases used when prompt_mode is 'text'.
        text_prompt = gr.Textbox(value='', interactive=True, label='text_prompt')
    with gr.Row():
        sam_ckpt_path = gr.Dropdown(choices=['weights/sam/mobile_sam.pt'],
                                    value='weights/sam/mobile_sam.pt',
                                    label='SAM ckpt_path')
        sam_version = gr.Dropdown(choices=['mobile_sam'],
                                  value='mobile_sam',
                                  label='SAM version')
        load_sam_bu = gr.Button(value='load SAM')
        sam_load_text = gr.Textbox(value='', interactive=False, label='sam_load')
        load_sam_bu.click(fn=load_sam, inputs=[sam_ckpt_path, sam_version], outputs=sam_load_text)
    with gr.Row():
        result_num_text = gr.Textbox(value='', interactive=False, label='result_num')
        # precision=0 makes the field deliver integers for list indexing.
        result_index = gr.Number(value=0, precision=0, label='result_index')
        show_result_bu = gr.Button(value='show_result')
        show_result_bu.click(fn=show_result, inputs=[result_index], outputs=[result_img])
    with gr.Row():
        mask_result_num_text = gr.Textbox(value='', interactive=False, label='mask_result_num')
        mask_result_index = gr.Number(value=0, precision=0, label='mask_result_index')
        show_mask_result_bu = gr.Button(value='show_mask_result')
        show_mask_result_bu.click(fn=show_mask_result, inputs=[mask_result_index], outputs=[result_img])
    with gr.Row():
        generate_masks_bu = gr.Button(value='SAM generate masks')
        sam_text = gr.Textbox(value='', interactive=False, label='SAM')
        generate_masks_bu.click(fn=generate_mask, inputs=[img, prompt_mode, text_prompt], outputs=[sam_text, result_num_text, mask_result_num_text])
    # Clicking the input image records a prompt for the current mode and redraws the output.
    img.select(show_prompt, [img, prompt_mode, pos_point], [result_img, point_coords_text, point_labels_text, boxes_point_text, boxes_text])
# Start the Gradio server only when run as a script, not when imported.
if __name__ == '__main__':
    demo.launch()