# test2/utils/test5.py
# Last commit: b759a29 by jianyouli — "Add application file1"
# (source viewer: raw | history | blame — 3.25 kB)
import os
# import matplotlib
# matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
import gradio as gr
import cv2
import numpy as np
import torch
from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from PIL import ImageDraw,Image
from utils.tools import box_prompt, format_results, point_prompt
from utils.tools_gradio import fast_process
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained MobileSAM model once, at import time.
# The checkpoint location can be overridden with the MOBILE_SAM_CHECKPOINT
# environment variable; the default is the original author's local path.
sam_checkpoint = os.environ.get(
    "MOBILE_SAM_CHECKPOINT",
    r"F:\zht\code\MobileSAM-master\weights\mobile_sam.pt",
)
model_type = "vit_t"  # key into sam_model_registry selecting the model variant
mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()  # inference mode

# Two front ends sharing the same loaded model. segment_with_boxs below uses
# `predictor`; `mask_generator` is not used by it.
mask_generator = SamAutomaticMaskGenerator(mobile_sam)
predictor = SamPredictor(mobile_sam)
@torch.no_grad()
def segment_with_boxs(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
    plot_save_path=r"F:\zht\code\2.png",
    show_plot=True,
):
    """Segment ``image`` with MobileSAM using the prompts held in module globals.

    Prompts are read from the module-level ``global_points`` and
    ``global_point_label`` lists, which the caller must set before calling.
    Both are reset to empty lists before a successful return.

    Each entry of ``global_points`` is in pixel coordinates of the ORIGINAL
    image and is rescaled here to match the resized image. An entry is one of:

    * a click point ``[x, y]``, or
    * a box ``[x1, y1, x2, y2]``. At most one box is allowed.

    ``global_point_label`` has one label per entry of ``global_points``. Labels
    belonging to box entries are ignored. When the lengths differ, the labels
    are assumed to belong to the point entries only.

    Args:
        image: PIL image to segment.
        input_size: Target length of the longer image side after resizing.
        better_quality, withContours, use_retina, mask_random_color:
            Forwarded to ``fast_process``.
        plot_save_path: File to save a debug plot of the selected mask to.
            ``None`` skips saving. Defaults to the original author's local path.
        show_plot: If true, display the debug plot with ``plt.show()``.
            Otherwise the debug figure is closed.

    Returns:
        ``(fig, image)``: the rendering returned by ``fast_process`` and the
        resized PIL image.

    Raises:
        ValueError: if ``global_points`` is empty, an entry has neither 2 nor
            4 values, or more than one box is given.
    """
    global global_points
    global global_point_label

    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    if len(global_points) == 0:
        raise ValueError("no prompt set: global_points is empty")
    # Rescale every coordinate from the original image to the resized one.
    scaled = [[int(c * scale) for c in entry] for entry in global_points]
    malformed = [entry for entry in scaled if len(entry) not in (2, 4)]
    if malformed:
        raise ValueError(
            f"prompt entries must have 2 (point) or 4 (box) values, got {malformed}"
        )
    point_idx = [i for i, entry in enumerate(scaled) if len(entry) == 2]
    boxes = [entry for entry in scaled if len(entry) == 4]
    if len(boxes) > 1:
        raise ValueError("only a single box prompt is supported per call")

    labels = list(global_point_label)
    if len(labels) == len(scaled):
        # Labels are aligned with all entries; keep only those of the points.
        labels = [labels[i] for i in point_idx]
    scaled_points = np.array([scaled[i] for i in point_idx])
    scaled_point_label = np.array(labels)

    nd_image = np.array(image)
    predictor.set_image(nd_image)

    # Boxes must go through `box=`. Passing 4 values as a "point" breaks
    # SAM's point path, which expects Nx2 coordinates.
    prompt = {"multimask_output": True}
    if point_idx:
        prompt["point_coords"] = scaled_points
        prompt["point_labels"] = scaled_point_label
    if boxes:
        prompt["box"] = np.array(boxes[0])
    masks, scores, logits = predictor.predict(**prompt)

    if point_idx:
        # Same point-based mask selection as before.
        results = format_results(masks, scores, logits, 0)
        annotation, _ = point_prompt(
            results, scaled_points, scaled_point_label, new_h, new_w
        )
    else:
        # Box-only prompt: keep the candidate mask with the highest score.
        annotation = masks[int(np.argmax(scores))]
    annotations = np.array([annotation])

    # Debug view of the selected mask. Use a fresh figure so repeated calls
    # don't draw onto, and add colorbars to, the previous figure.
    mask_fig = plt.figure()
    plt.imshow(annotations[0], cmap="viridis")
    plt.colorbar()
    if plot_save_path is not None:
        plt.savefig(plot_save_path)
    if show_plot:
        plt.show()
    else:
        plt.close(mask_fig)

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        # 1024 // input_size is 0 for input_size > 1024. Clamp to 1.
        # NOTE(review): assumes fast_process cannot use a zero scale; confirm.
        scale=max(1, 1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    # Consume the prompts so the next call starts from an empty prompt set.
    global_points = []
    global_point_label = []
    return fig, image
#################################################
if __name__ == "__main__":
    import sys

    # Image to segment: the first CLI argument, else the original author's
    # local sample image.
    path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"F:\zht\code\MobileSAM-master\app\assets\05.jpg"
    )
    # Open inside `with` so the file handle is released. Convert to RGB so
    # grayscale or alpha-channel images still give an HxWx3 array downstream.
    with Image.open(path) as img:
        image1 = img.convert("RGB")
    print(image1.size)

    # Prompts read (and then cleared) by segment_with_boxs, in original-image
    # pixel coordinates. Here: one 4-value [x1, y1, x2, y2] box with label 1.
    global_points = [[324, 740, 1448, 1192]]
    global_point_label = [1]
    segment_with_boxs(
        image1,
        input_size=1024,
        better_quality=False,
        withContours=True,
        use_retina=True,
        mask_random_color=True,
    )