# ai / sam_everything.py
# neoguojing
# finish ocr
# ac510cd
from segment_anything import SamPredictor, sam_model_registry,SamAutomaticMaskGenerator
from PIL import Image
import torch
from detectron2.data.detection_utils import read_image,pil_image_to_numpy
from detectron2.utils.visualizer import Visualizer
import numpy as np
from skimage import measure
import threading
# Run on the first CUDA device when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def seg_with_promp(imput_image, point_coords=None, box=None):
    """Segment an image from a box or point prompt with a freshly loaded SAM model.

    A new ``vit_b`` model is loaded from ``./sam_vit_b_01ec64.pth`` and moved
    to the module-level ``device`` on every call.

    Args:
        imput_image: PIL image (converted with ``pil_image_to_numpy``) or an
            image array accepted by ``SamPredictor.set_image``.
        point_coords: optional array-like of N point prompts; every point is
            given label 1.
        box: optional box prompt passed to ``predict(box=...)``. When given,
            ``point_coords`` is ignored.

    Returns:
        ``(masks, pil_images)``: the masks returned by ``predict`` and the
        list produced by ``draw_bitmask`` with all masks overlaid.

    Raises:
        ValueError: if neither ``box`` nor ``point_coords`` is provided.
    """
    # Fail fast, before loading the model and embedding the image. Without a
    # prompt `masks` would stay None and crash later on `masks.shape`.
    if box is None and point_coords is None:
        raise ValueError("seg_with_promp requires either box or point_coords")
    if isinstance(imput_image, Image.Image):
        imput_image = pil_image_to_numpy(imput_image)
    sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth").to(device)
    predictor = SamPredictor(sam)
    predictor.set_image(imput_image)
    if box is not None:
        # asarray lets callers pass plain lists as well as ndarrays.
        masks, _, _ = predictor.predict(box=np.asarray(box))
    else:
        point_coords = np.asarray(point_coords)
        point_labels = np.ones(point_coords.shape[0])
        masks, _, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels)
    print("seg_with_promp:", masks.shape)
    pil_images = draw_bitmask(imput_image, masks)
    return masks, pil_images
def seg_all(imput_image):
    """Segment everything in an image with SAM's automatic mask generator.

    A new ``vit_b`` model is loaded from ``./sam_vit_b_01ec64.pth`` on every
    call.

    Args:
        imput_image: PIL image (converted with ``pil_image_to_numpy``) or an
            image array accepted by ``SamAutomaticMaskGenerator.generate``.

    Returns:
        ``(masks, pil_images)``: the generator's mask records and the list
        produced by ``draw_bitmask`` with all masks overlaid.
    """
    if isinstance(imput_image, Image.Image):
        imput_image = pil_image_to_numpy(imput_image)
    # Move the model to `device`, matching seg_with_promp and SamAnything, so
    # the GPU is used when available instead of always running on the CPU.
    sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth").to(device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(imput_image)
    # draw_polygon and draw_bitmask_split are alternative renderers for `masks`.
    pil_images = draw_bitmask(imput_image, masks)
    return masks, pil_images
# Generate one image per binary mask.
def draw_bitmask_split(np_image, masks):
    """Draw each mask with a fresh ``Visualizer`` built from ``np_image``.

    Args:
        np_image: image array handed to ``Visualizer``.
        masks: iterable of dicts carrying a ``"segmentation"`` mask.

    Returns:
        A list with one PIL image per mask, or an empty list when ``masks``
        is empty.
    """
    pil_images = []
    for i, obj in enumerate(masks):
        print("segmentation:", obj["segmentation"].shape)
        view = Visualizer(np_image)
        view.draw_binary_mask(obj["segmentation"])
        vis_image = view.get_output()
        # Collect every mask's image, not just the last one; this also keeps
        # `pil_images` bound when `masks` is empty.
        pil_images.extend(visimage_to_pil([vis_image], idx=i))
    return pil_images
# Draw binary masks onto a single image.
def draw_bitmask(np_image, masks):
    """Overlay every mask on one ``Visualizer`` built from ``np_image``.

    Args:
        np_image: image array handed to ``Visualizer``.
        masks: either mask-record dicts with a ``"segmentation"`` entry (as
            produced for ``seg_all``) or bare mask arrays (as returned by
            ``predict`` in ``seg_with_promp``).

    Returns:
        The list from ``visimage_to_pil`` holding the single rendered image.
    """
    view = Visualizer(np_image)
    for obj in masks:
        # Check the type rather than `"segmentation" in obj`. On a bare
        # ndarray mask that test performs an elementwise array-vs-string
        # comparison.
        if isinstance(obj, dict):
            print("segmentation:", obj["segmentation"].shape)
            view.draw_binary_mask(obj["segmentation"])
        else:
            view.draw_binary_mask(obj)
    vis_image = view.get_output()
    pil_images = visimage_to_pil([vis_image])
    return pil_images
# Draw polygon outlines for the masks onto a single image.
def draw_polygon(np_image, masks):
    """Outline each mask's polygon (color "k") on one ``Visualizer``.

    Args:
        np_image: image array handed to ``Visualizer``.
        masks: iterable of dicts carrying a ``"segmentation"`` mask; each one
            is converted with ``bitmask_to_polygon``.

    Returns:
        The list from ``visimage_to_pil`` holding the single rendered image.
    """
    canvas = Visualizer(np_image)
    for record in masks:
        canvas.draw_polygon(bitmask_to_polygon(record["segmentation"]), "k")
    return visimage_to_pil([canvas.get_output()])
# Convert a binary mask into a polygon.
def bitmask_to_polygon(mask):
    """Trace a mask's contours at level 0.5 and return one polygon.

    Args:
        mask: 2-D mask array accepted by ``skimage.measure.find_contours``.

    Returns:
        An ``(N, 2)`` array with the last contour reported by ``find_contours``,
        its columns flipped from (row, col) to (col, row), i.e. (x, y). Returns
        an empty ``(0, 2)`` array when no contour is found.
    """
    col_mask = np.asfortranarray(mask)
    contours = measure.find_contours(col_mask, 0.5)
    # find_contours returns a plain list, which has no `.shape` attribute.
    print("contours------", len(contours))
    polygon = np.zeros((0, 2))
    for i, contour in enumerate(contours):
        # Swap (row, col) to (x, y) for drawing.
        polygon = np.flip(contour, axis=1)
        print(f"polygon_{i}", polygon.shape)
    return polygon
# Convert Visualizer output images to PIL.
def visimage_to_pil(visimages, need_save=False, idx=0):
    """Convert detectron2 ``VisImage`` objects to PIL images.

    Args:
        visimages: iterable of objects exposing ``get_image()``.
        need_save: when True, also save each image as ``f"{idx}_{i}.jpg"`` in
            the current working directory.
        idx: prefix used in the saved file names.

    Returns:
        A list of PIL images, one per input.
    """
    pil_images = []
    for i, visimage in enumerate(visimages):
        # get_image() already returns an RGB uint8 array, which is the layout
        # PIL expects. Reversing the channel axis (the BGR conversion used for
        # OpenCV) swapped red and blue.
        pil_image = Image.fromarray(visimage.get_image())
        if need_save:
            pil_image.save(f"{idx}_{i}.jpg")
        pil_images.append(pil_image)
    return pil_images
def image_to_mask(image, threshold=128):
    """Binarize a PIL image into a 0/255 mask image.

    The image is converted to mode ``'L'`` unless it already has that mode.
    Pixels strictly greater than ``threshold`` become 255 and all others
    become 0.

    Args:
        image: PIL image.
        threshold: cutoff compared with each grayscale pixel value.

    Returns:
        A PIL image built from the resulting uint8 array.
    """
    gray = image if image.mode == 'L' else image.convert('L')
    foreground = np.array(gray) > threshold
    return Image.fromarray(np.uint8(foreground) * 255)
class SamAnything:
    """Process-wide singleton wrapping one SAM ``vit_b`` model.

    The model, a ``SamPredictor`` (prompted segmentation) and a
    ``SamAutomaticMaskGenerator`` (segment everything) are built on the first
    instantiation and reused by every later ``SamAnything(...)`` call.
    Arguments passed to later instantiations are ignored.

    NOTE(review): ``seg_with_promp`` calls ``set_image`` and then ``predict``
    on the shared predictor. Concurrent calls from several threads are not
    serialized here.
    """

    _instance = None
    _lock = threading.Lock()  # guards one-time construction of the singleton
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __new__(cls, *args, **kwargs):
        # Double-checked locking. Publish the instance only after
        # _initialize succeeds, so the unlocked fast-path check can never
        # hand out a half-built object, and a failed load can be retried.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super(SamAnything, cls).__new__(cls)
                    instance._initialize(*args, **kwargs)
                    cls._instance = instance
        return cls._instance

    def _initialize(self, checkpoint_path="./sam_vit_b_01ec64.pth"):
        """Load the ``vit_b`` checkpoint onto ``device`` and build the helpers."""
        self.sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path).to(self.device)
        self.predictor = SamPredictor(self.sam)
        self.mask_generator = SamAutomaticMaskGenerator(self.sam)

    def seg_with_promp(self, imput_image, point_coords=None, box=None):
        """Segment the object indicated by a box or by point prompts.

        Args:
            imput_image: PIL image (converted with ``pil_image_to_numpy``) or
                an image array accepted by ``SamPredictor.set_image``.
            point_coords: optional array-like of N point prompts; every point
                is given label 1.
            box: optional box prompt passed to ``predict(box=...)``. When
                given, ``point_coords`` is ignored.

        Returns:
            ``(masks, pil_images)``: the masks returned by ``predict`` and the
            list from ``draw_bitmask`` with all masks overlaid.

        Raises:
            ValueError: if neither ``box`` nor ``point_coords`` is provided.
        """
        # Fail fast, before the costly image embedding. Without a prompt
        # `masks` would stay None and crash later on `masks.shape`.
        if box is None and point_coords is None:
            raise ValueError("seg_with_promp requires either box or point_coords")
        if isinstance(imput_image, Image.Image):
            imput_image = pil_image_to_numpy(imput_image)
        self.predictor.set_image(imput_image)
        if box is not None:
            # asarray lets callers pass plain lists as well as ndarrays.
            masks, _, _ = self.predictor.predict(box=np.asarray(box))
        else:
            point_coords = np.asarray(point_coords)
            point_labels = np.ones(point_coords.shape[0])
            masks, _, _ = self.predictor.predict(point_coords=point_coords, point_labels=point_labels)
        print("seg_with_promp:", masks.shape)
        pil_images = self.draw_bitmask(imput_image, masks)
        return masks, pil_images

    def seg_all(self, imput_image):
        """Run the automatic mask generator and overlay all resulting masks.

        Returns:
            ``(masks, pil_images)``: the generator's mask records and the list
            from ``draw_bitmask``.
        """
        if isinstance(imput_image, Image.Image):
            imput_image = pil_image_to_numpy(imput_image)
        masks = self.mask_generator.generate(imput_image)
        pil_images = self.draw_bitmask(imput_image, masks)
        return masks, pil_images

    @staticmethod
    def draw_bitmask_split(np_image, masks):
        """Draw each ``"segmentation"`` mask with its own ``Visualizer``.

        Returns:
            A list with one PIL image per mask (empty for empty ``masks``).
        """
        pil_images = []
        for i, obj in enumerate(masks):
            print("segmentation:", obj["segmentation"].shape)
            view = Visualizer(np_image)
            view.draw_binary_mask(obj["segmentation"])
            vis_image = view.get_output()
            pil_images.extend(SamAnything.visimage_to_pil([vis_image], idx=i))
        return pil_images

    @staticmethod
    def draw_bitmask(np_image, masks):
        """Overlay every mask on one ``Visualizer`` built from ``np_image``.

        ``masks`` may hold mask-record dicts with a ``"segmentation"`` entry
        (from ``seg_all``) or bare mask arrays (from ``predict``).
        """
        view = Visualizer(np_image)
        for obj in masks:
            # Check the type rather than `"segmentation" in obj`. On a bare
            # ndarray mask that test performs an elementwise array-vs-string
            # comparison.
            if isinstance(obj, dict):
                print("segmentation:", obj["segmentation"].shape)
                view.draw_binary_mask(obj["segmentation"])
            else:
                view.draw_binary_mask(obj)
        vis_image = view.get_output()
        pil_images = SamAnything.visimage_to_pil([vis_image])
        return pil_images

    @staticmethod
    def draw_polygon(np_image, masks):
        """Outline every contour of every ``"segmentation"`` mask (color "k")."""
        view = Visualizer(np_image)
        for obj in masks:
            # One mask can yield several contours, while Visualizer.draw_polygon
            # takes a single polygon, so draw them one at a time.
            for polygon in SamAnything.bitmask_to_polygon(obj["segmentation"]):
                view.draw_polygon(polygon, "k")
        vis_image = view.get_output()
        pil_images = SamAnything.visimage_to_pil([vis_image])
        return pil_images

    @staticmethod
    def bitmask_to_polygon(mask):
        """Trace a mask's contours at level 0.5.

        Returns:
            A list of ``(N, 2)`` arrays, one per contour, with columns flipped
            from ``find_contours``' (row, col) order to (x, y).
        """
        col_mask = np.asfortranarray(mask)
        contours = measure.find_contours(col_mask, 0.5)
        print("contours------", len(contours))
        polygons = []
        for i, contour in enumerate(contours):
            # Keep the flipped copy. Previously the flip was computed and then
            # discarded, so the raw (row, col) contours were returned.
            polygon = np.flip(contour, axis=1)
            print(f"polygon_{i}", polygon.shape)
            polygons.append(polygon)
        return polygons

    @staticmethod
    def visimage_to_pil(visimages, need_save=True, idx=0):
        """Convert ``VisImage`` objects (RGB ``get_image()``) to PIL images.

        NOTE(review): ``need_save`` defaults to True here, so every call writes
        ``f"{idx}_{i}.jpg"`` files into the current working directory.
        """
        pil_images = []
        for i, visimage in enumerate(visimages):
            visualized_image = visimage.get_image()
            pil_image = Image.fromarray(visualized_image)
            if need_save:
                pil_image.save(f"{idx}_{i}.jpg")
            pil_images.append(pil_image)
        return pil_images

    @staticmethod
    def image_to_mask(image, threshold=128):
        """Binarize a PIL image: convert to mode 'L', map > threshold to 255, else 0."""
        # Convert to grayscale if needed.
        if image.mode != 'L':
            image = image.convert('L')
        # Map pixel values to booleans.
        mask_array = np.array(image) > threshold
        # Build a same-sized 0/255 image from the boolean map.
        mask_image = Image.fromarray(np.uint8(mask_array) * 255)
        return mask_image
# if __name__ == "__main__":
#     np_image = read_image("./test/face1.jpeg")
#     print("np_image:", np_image.shape)
#     sam = SamAnything("./sam_vit_b_01ec64.pth")
#     sam.seg_all(np_image)