|
import os |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry |
|
from PIL import ImageDraw,Image |
|
from utils.tools import box_prompt, format_results, point_prompt |
|
from utils.tools_gradio import fast_process |
|
|
|
# Run on the first CUDA device when available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the MobileSAM weights. Set the MOBILE_SAM_CHECKPOINT environment
# variable to override the developer-machine default below.
sam_checkpoint = os.environ.get(
    "MOBILE_SAM_CHECKPOINT",
    r"F:\zht\code\MobileSAM-master\weights\mobile_sam.pt",
)
model_type = "vit_t"

# Build the model, move it to `device` and switch to inference mode.
mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()

# Prompt-free automatic mask generator (segment_with_boxs does not use it).
mask_generator = SamAutomaticMaskGenerator(mobile_sam)
# Prompt-driven predictor used by segment_with_boxs.
predictor = SamPredictor(mobile_sam)

# Prompt state read (and reset after a successful run) by segment_with_boxs.
# Defined here so the function also works when this module is imported
# instead of being run as a script.
global_points = []
global_point_label = []
|
|
|
|
|
|
|
def _split_prompts(points, labels, scale):
    """Scale prompt coordinates and split them into SAM point and box inputs.

    Args:
        points: iterable of prompts; each is ``[x, y]`` (point) or
            ``[x1, y1, x2, y2]`` (box, XYXY) in original-image pixels.
        labels: labels for the prompts, indexed like ``points``; only the
            entries that belong to point prompts are used.
        scale: factor applied to every coordinate (truncated with ``int``).

    Returns:
        ``(point_coords, point_labels, box)``: an (N, 2) int array and an
        (N,) array for the point prompts (both ``None`` when there are none),
        and a length-4 int array for the box (``None`` when there is none).

    Raises:
        ValueError: a prompt is neither 2 nor 4 coordinates long, a point
            prompt has no label, or more than one box is given (the single
            ``box`` argument of ``SamPredictor.predict`` takes one box).
    """
    coords, point_labels, boxes = [], [], []
    for index, prompt in enumerate(points):
        scaled = [int(value * scale) for value in prompt]
        if len(scaled) == 2:
            if index >= len(labels):
                raise ValueError(f"point prompt #{index} has no label")
            coords.append(scaled)
            point_labels.append(labels[index])
        elif len(scaled) == 4:
            boxes.append(scaled)
        else:
            raise ValueError(
                f"prompt #{index} has {len(scaled)} coordinates; "
                "expected [x, y] or [x1, y1, x2, y2]"
            )
    if len(boxes) > 1:
        raise ValueError(f"only one box prompt is supported, got {len(boxes)}")

    if not coords:
        return None, None, (np.array(boxes[0]) if boxes else None)
    return (
        np.array(coords),
        np.array(point_labels),
        np.array(boxes[0]) if boxes else None,
    )


def _plot_debug_mask(mask, save_path=None, show=False):
    """Plot ``mask`` with a colorbar on a dedicated figure.

    The figure is optionally saved to ``save_path`` and/or shown, and is
    always closed afterwards so repeated calls do not pile up figures or
    colorbars in pyplot's global state.
    """
    debug_fig, ax = plt.subplots()
    try:
        im = ax.imshow(mask, cmap="viridis")
        debug_fig.colorbar(im, ax=ax)
        if save_path is not None:
            debug_fig.savefig(save_path)
        if show:
            plt.show()
    finally:
        plt.close(debug_fig)


@torch.no_grad()
def segment_with_boxs(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
    *,
    debug_mask_path=None,
    show_debug_mask=False,
):
    """Segment ``image`` using the prompts held in the module-level globals.

    Prompts come from ``global_points`` / ``global_point_label``. Both are
    reset to empty lists after a successful run. Each ``global_points`` entry
    is either:

    * ``[x, y]``: a point prompt whose label is the ``global_point_label``
      entry at the same index (SAM convention: 1 = foreground,
      0 = background);
    * ``[x1, y1, x2, y2]``: a box prompt in XYXY pixel coordinates (its
      label entry, if any, is ignored). At most one box is allowed.

    Args:
        image: PIL-style image (``.size`` / ``.resize`` / ``.convert``).
            Its longest side is resized to ``input_size`` and all prompt
            coordinates are scaled by the same factor.
        input_size: target length of the longest side (cast to ``int``).
        better_quality, withContours, use_retina, mask_random_color:
            forwarded unchanged to ``fast_process``.
        debug_mask_path: if given, save a colorbar plot of the selected mask
            to this path.
        show_debug_mask: if true, display that plot (blocking ``plt.show``).

    Returns:
        ``(fig, image)``: the ``fast_process`` rendering and the resized
        image. When there are no prompts, returns ``(image, image)`` with the
        resized image and runs no prediction.

    Raises:
        ValueError: malformed prompts (see ``_split_prompts``).
    """
    global global_points
    global global_point_label

    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    point_coords, point_labels, box = _split_prompts(
        global_points, global_point_label, scale
    )
    if point_coords is None and box is None:
        # Nothing to segment; the predictor needs at least one prompt.
        return image, image

    # set_image expects an HxWx3 uint8 array. Converting keeps RGBA,
    # grayscale or palette inputs from yielding the wrong channel count
    # (a no-op copy for RGB images).
    predictor.set_image(np.array(image.convert("RGB")))
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        box=box,
        multimask_output=True,
    )

    if box is not None:
        # Box prompt: keep the candidate mask the model scored highest.
        annotation = masks[int(np.argmax(scores))]
    else:
        results = format_results(masks, scores, logits, 0)
        annotation, _ = point_prompt(
            results, point_coords, point_labels, new_h, new_w
        )
    annotations = np.array([annotation])

    if debug_mask_path is not None or show_debug_mask:
        _plot_debug_mask(annotations[0], debug_mask_path, show_debug_mask)

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        # Integer ratio as in the original code; it is 0 when
        # input_size > 1024. NOTE(review): confirm fast_process's expectation.
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    # Prompts are single-use: clear them for the next request.
    global_points = []
    global_point_label = []
    return fig, image
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Optional first command-line argument overrides the demo image path.
    path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"F:\zht\code\MobileSAM-master\app\assets\05.jpg"
    )

    # Context manager closes the file handle that Image.open keeps open.
    with Image.open(path) as image1:
        print(image1.size)

        # A single box prompt in XYXY pixel coordinates of the original image.
        global_points = [[324, 740, 1448, 1192]]
        global_point_label = [1]

        fig, _ = segment_with_boxs(
            image1,
            input_size=1024,
            better_quality=False,
            withContours=True,
            use_retina=True,
            mask_random_color=True,
        )

    # Show the rendered result instead of discarding it.
    # NOTE(review): assumes fast_process returns something plt.imshow
    # accepts (e.g. a PIL image). Confirm against utils.tools_gradio.
    plt.imshow(fig)
    plt.axis("off")
    plt.show()