"""Prompted segmentation of a single image with MobileSAM.

Callers put their prompts into the module-level ``global_points`` /
``global_point_label`` lists and then call :func:`segment_with_boxs`.
Running the file as a script segments an example image with a box prompt.
"""
import os

# To force an interactive Qt backend, uncomment:
# import matplotlib
# matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import gradio as gr
import cv2
import numpy as np
import torch
from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from PIL import ImageDraw, Image

from utils.tools import box_prompt, format_results, point_prompt
from utils.tools_gradio import fast_process

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model. The checkpoint location can be overridden with
# the MOBILE_SAM_CHECKPOINT environment variable.
sam_checkpoint = os.environ.get(
    "MOBILE_SAM_CHECKPOINT",
    r"F:\zht\code\MobileSAM-master\weights\mobile_sam.pt",
)
model_type = "vit_t"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()

mask_generator = SamAutomaticMaskGenerator(mobile_sam)
predictor = SamPredictor(mobile_sam)

# Prompts consumed (and then cleared) by segment_with_boxs. Each entry of
# global_points is either an (x, y) point or an (x1, y1, x2, y2) box, given in
# the coordinates of the original, unresized image. global_point_label holds
# the label for the entry at the same index (used for point entries only;
# 1 = foreground, 0 = background).
global_points = []
global_point_label = []


def _split_prompts(prompts, labels, scale):
    """Scale raw prompts by ``scale`` and split them into points and boxes.

    Returns ``(points, point_labels, boxes)``: an ``(N, 2)`` array of point
    coordinates, an ``(N,)`` array of their labels, and a list of length-4
    XYXY box arrays.

    Raises:
        ValueError: if ``prompts`` is empty, an entry is neither 2 nor 4
            coordinates long, or a point entry has no label at its index.
    """
    if not prompts:
        raise ValueError("No prompts set: fill global_points before segmenting.")
    points, point_labels, boxes = [], [], []
    for i, prompt in enumerate(prompts):
        scaled = [int(coord * scale) for coord in prompt]
        if len(scaled) == 4:
            boxes.append(np.array(scaled))
        elif len(scaled) == 2:
            if i >= len(labels):
                raise ValueError(f"Point prompt #{i} {list(prompt)} has no label.")
            points.append(scaled)
            point_labels.append(labels[i])
        else:
            raise ValueError(
                f"Prompt #{i} {list(prompt)} must be (x, y) or (x1, y1, x2, y2)."
            )
    return np.array(points).reshape(-1, 2), np.array(point_labels), boxes


def _segment_boxes(boxes, points, point_labels):
    """Return the union of the highest-scoring SAM mask for each box prompt.

    ``predictor.set_image`` must already have been called. Any point prompts
    are passed along with every box to refine it.
    """
    has_points = len(points) > 0
    point_coords = points if has_points else None
    labels = point_labels if has_points else None
    combined = None
    # SamPredictor.predict takes a single XYXY box per call, so predict each
    # box separately and OR the best masks together.
    for box in boxes:
        masks, scores, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=labels,
            box=box,
            multimask_output=True,
        )
        best = masks[int(np.argmax(scores))]
        combined = best if combined is None else (combined | best)
    return combined


def _plot_mask(mask, save_path):
    """Debug helper: plot ``mask`` with a colorbar, save it to ``save_path`` and show it."""
    debug_fig = plt.figure()
    plt.imshow(mask, cmap="viridis")  # use the 'viridis' colormap
    plt.colorbar()
    plt.savefig(save_path)
    plt.show()
    plt.close(debug_fig)


@torch.no_grad()
def segment_with_boxs(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
    debug_plot_path=None,
):
    """Segment ``image`` using the prompts stored in the module globals.

    Box entries in ``global_points`` are segmented with SAM's box prompt.
    Otherwise the point prompts are used and the mask is selected with
    ``point_prompt``.

    Args:
        image: PIL image to segment.
        input_size: Target length of the longer image side. The image is
            resized to it before inference, and the prompt coordinates are
            scaled by the same factor.
        better_quality, withContours, use_retina, mask_random_color:
            Forwarded to ``fast_process`` for rendering.
        debug_plot_path: If given, the selected mask is additionally plotted
            with matplotlib, saved to this path and shown.

    Returns:
        ``(fig, image)``: the overlay rendered by ``fast_process`` and the
        resized RGB image that was segmented.

    Raises:
        ValueError: if no prompts are set or a prompt is malformed.

    ``global_points`` and ``global_point_label`` are reset to empty lists when
    this function returns, including when it raises.
    """
    global global_points
    global global_point_label

    try:
        input_size = int(input_size)
        w, h = image.size
        scale = input_size / max(w, h)
        new_w = int(w * scale)
        new_h = int(h * scale)
        # SamPredictor.set_image expects an HxWx3 RGB array; converting keeps
        # grayscale/RGBA inputs from producing a wrongly shaped array.
        image = image.convert("RGB").resize((new_w, new_h))

        points, point_labels, boxes = _split_prompts(
            global_points, global_point_label, scale
        )
        predictor.set_image(np.array(image))

        if boxes:
            mask = _segment_boxes(boxes, points, point_labels)
        else:
            masks, scores, logits = predictor.predict(
                point_coords=points,
                point_labels=point_labels,
                multimask_output=True,
            )
            results = format_results(masks, scores, logits, 0)
            mask, _ = point_prompt(results, points, point_labels, new_h, new_w)
        annotations = np.array([mask])

        if debug_plot_path is not None:
            _plot_mask(annotations[0], debug_plot_path)

        fig = fast_process(
            annotations=annotations,
            image=image,
            device=device,
            scale=(1024 // input_size),
            better_quality=better_quality,
            mask_random_color=mask_random_color,
            bbox=None,
            use_retina=use_retina,
            withContours=withContours,
        )
        return fig, image
    finally:
        # Always clear the prompts so a failed call does not leak them into
        # the next one.
        global_points = []
        global_point_label = []


if __name__ == "__main__":
    path = r"F:\zht\code\MobileSAM-master\app\assets\05.jpg"
    with Image.open(path) as image1:
        print(image1.size)
        # A point prompt would instead be: global_points = [[1069, 928]]
        global_points = [[324, 740, 1448, 1192]]  # one (x1, y1, x2, y2) box
        global_point_label = [1]
        segment_with_boxs(
            image1,
            input_size=1024,
            better_quality=False,
            withContours=True,
            use_retina=True,
            mask_random_color=True,
            debug_plot_path=r"F:\zht\code\2.png",
        )