"""Gradio demo: turn a scribble drawn on an image into a SAM segmentation mask.

The user uploads an image and sketches over the object of interest.  The
sketch is used as an in-context reference mask: the image features under the
sketch are averaged into a target embedding, a cosine-similarity map over the
same image yields one positive and one negative point prompt, and the
predictor is then run three times (initial prediction plus two cascaded
refinements) to produce the final mask.
"""
import argparse
import os
import pathlib
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F

# NOTE: the star import is kept (and kept after the imports above) so that any
# names it provides or overrides behave exactly as before.
from show import *
from per_segment_anything import sam_model_registry, SamPredictor

parser = argparse.ArgumentParser()
parser.add_argument("-op", "--output-path", type=str, default='default')
args = parser.parse_args()

# Model selection.  For the original SAM ViT-H weights use
# ('vit_h', 'sam_vit_h_4b8939.pth') instead; append ``.cuda()`` to the model
# built in ``_load_sam`` to run on GPU.
SAM_TYPE = 'vit_t'
SAM_CKPT = 'weights/mobile_sam.pt'  # MobileSAM


class ImageMask(gr.components.Image):
    """Image component preset with ``source="upload"`` and ``tool="select"``.

    Not referenced by the interface defined in this file, which uses
    ``gr.ImageMask``; kept so existing importers keep working.
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool='select', interactive=True, **kwargs)

    def preprocess(self, x):
        # Plain pass-through to the parent implementation.
        return super().preprocess(x)


def point_selection(mask_sim, topk=1):
    """Pick the most and least similar locations of a 2-D similarity map.

    Args:
        mask_sim: 2-D torch tensor of similarity scores (rows x columns).
        topk: number of highest-scoring and of lowest-scoring points to pick.

    Returns:
        ``(topk_xy, topk_label, last_xy, last_label)`` where the ``*_xy``
        values are ``(topk, 2)`` numpy arrays of ``(column, row)`` coordinates
        and the labels are numpy arrays of ones (highest-scoring points) and
        zeros (lowest-scoring points).
    """
    # Only the column count is needed to unravel flat indices.
    _, n_cols = mask_sim.shape
    flat_sim = mask_sim.flatten(0)

    def _flat_to_xy(flat_idx):
        # Unravel flat indices into (column, row) pairs, one pair per row.
        rows = (flat_idx // n_cols).unsqueeze(0)
        cols = flat_idx - rows * n_cols
        return torch.cat((cols, rows), dim=0).permute(1, 0).cpu().numpy()

    # Top-1 point selection
    topk_xy = _flat_to_xy(flat_sim.topk(topk)[1])
    topk_label = np.array([1] * topk)

    # Top-last point selection
    last_xy = _flat_to_xy(flat_sim.topk(topk, largest=False)[1])
    last_label = np.array([0] * topk)

    return topk_xy, topk_label, last_xy, last_label


@lru_cache(maxsize=1)
def _load_sam():
    """Build the SAM model once and reuse it for every request."""
    return sam_model_registry[SAM_TYPE](checkpoint=SAM_CKPT)


def inference_scribble(image):
    """Segment the region the user scribbled on.

    Args:
        image: dict with PIL images under the keys ``"image"`` (the uploaded
            picture) and ``"mask"`` (the user's strokes).

    Returns:
        A two-element list of PIL images: the mask blended over the input
        image, and the mask alone (masked pixels coloured ``(128, 0, 0)``).

    Raises:
        gr.Error: if no image was supplied, if nothing was drawn on it, or if
            the model predicts an empty mask.
    """
    if image is None:
        raise gr.Error("Please upload an image and draw on it first.")

    # In-context image and mask
    ic_image = np.array(image["image"].convert("RGB"))
    ic_mask = np.array(image["mask"].convert("RGB"))

    # The model is cached, but the predictor keeps per-image state, so a fresh
    # one is created for each request.
    predictor = SamPredictor(_load_sam())

    # Image features encoding.  ``set_image(image, mask)`` is the project's
    # predictor API; it returns the mask tensor that is resized below.
    ref_mask = predictor.set_image(ic_image, ic_mask)
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)

    ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
    ref_mask = ref_mask.squeeze()[0]

    # Without any selected feature the mean below would be NaN and every later
    # step would operate on garbage, so fail early with a readable message.
    scribbled = ref_mask > 0
    if not bool(scribbled.any()):
        raise gr.Error("No scribble found: draw on the region you want to mask.")

    # Target feature extraction
    print("======> Obtain Location Prior")
    target_embedding = ref_feat[scribbled].mean(0).unsqueeze(0)
    target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
    target_embedding = target_embedding.unsqueeze(0)

    test_image = ic_image

    print("======> Testing Image")

    # Image feature encoding
    # NOTE(review): this re-encodes the same image as above; it is presumably
    # redundant, but is kept until the predictor implementation is confirmed.
    predictor.set_image(test_image)
    test_feat = predictor.features.squeeze()

    # Cosine similarity
    C, h, w = test_feat.shape
    test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
    test_feat = test_feat.reshape(C, h * w)
    sim = target_feat @ test_feat

    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim,
        input_size=predictor.input_size,
        original_size=predictor.original_size).squeeze()

    # Positive-negative location prior
    topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
    topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
    topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)

    # Obtain the target guidance for cross-attention layers
    sim = (sim - sim.mean()) / torch.std(sim)
    sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
    attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)

    # First-step prediction
    masks, scores, logits, _ = predictor.predict(
        point_coords=topk_xy,
        point_labels=topk_label,
        multimask_output=True,
        attn_sim=attn_sim,  # Target-guided Attention
        target_embedding=target_embedding  # Target-semantic Prompting
    )
    best_idx = 0

    # Cascaded Post-refinement-1
    masks, scores, logits, _ = predictor.predict(
        point_coords=topk_xy,
        point_labels=topk_label,
        mask_input=logits[best_idx: best_idx + 1, :, :],
        multimask_output=True)
    best_idx = np.argmax(scores)

    # Cascaded Post-refinement-2: prompt again with the bounding box of the
    # current best mask.
    y, x = np.nonzero(masks[best_idx])
    if x.size == 0:
        # min()/max() of an empty array would raise an opaque ValueError.
        raise gr.Error("No region could be segmented from this scribble; "
                       "try drawing over more of the object.")
    input_box = np.array([x.min(), y.min(), x.max(), y.max()])
    masks, scores, logits, _ = predictor.predict(
        point_coords=topk_xy,
        point_labels=topk_label,
        box=input_box[None, :],
        mask_input=logits[best_idx: best_idx + 1, :, :],
        multimask_output=True)
    best_idx = np.argmax(scores)

    final_mask = masks[best_idx]
    mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
    mask_colors[final_mask, :] = np.array([[128, 0, 0]])

    # uint8 HxWx3 arrays are read as RGB by ``Image.fromarray`` without an
    # explicit mode argument.
    blended = (mask_colors * 0.6 + test_image * 0.4).astype('uint8')
    return [Image.fromarray(blended),
            Image.fromarray(mask_colors.astype('uint8'))]


main_scribble = gr.Interface(
    fn=inference_scribble,
    inputs=gr.ImageMask(label="[Stroke] Draw on Image", type='pil'),
    outputs=[
        gr.outputs.Image(type="pil", label="Mask with Image"),
        gr.outputs.Image(type="pil", label="Mask")
    ],
    allow_flagging="never",
    title="SAM based Segment Annotator.",
    description='Sketch the portion where you want to create Mask.',
    examples=[
        "./cardamage_example/0006.JPEG",
        "./cardamage_example/0008.JPEG",
        "./cardamage_example/0206.JPEG"
    ]
)

main_scribble.launch(share=True)