Spaces:
Sleeping
Sleeping
from segment_anything import SamPredictor, sam_model_registry,SamAutomaticMaskGenerator | |
from PIL import Image | |
import torch | |
from detectron2.data.detection_utils import read_image,pil_image_to_numpy | |
from detectron2.utils.visualizer import Visualizer | |
import numpy as np | |
from skimage import measure | |
import threading | |
# Torch device used by the module-level helpers: CUDA when torch reports an
# available GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def seg_with_promp(imput_image,point_coords=None,box=None): | |
if isinstance(imput_image, Image.Image): | |
imput_image = pil_image_to_numpy(imput_image) | |
point_labels = None | |
if point_coords is not None: | |
point_labels = np.ones(point_coords.shape[0]) | |
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth").to(device) | |
predictor = SamPredictor(sam) | |
predictor.set_image(imput_image) | |
masks = None | |
if box is not None: | |
masks, _, _ = predictor.predict(box=box) | |
elif point_coords is not None and point_labels is not None: | |
masks, _, _ = predictor.predict(point_coords=point_coords,point_labels=point_labels) | |
print("seg_with_promp:",masks.shape) | |
pil_images = draw_bitmask(imput_image,masks) | |
return masks,pil_images | |
def seg_all(imput_image): | |
if isinstance(imput_image, Image.Image): | |
imput_image = pil_image_to_numpy(imput_image) | |
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth") | |
mask_generator = SamAutomaticMaskGenerator(sam) | |
masks = mask_generator.generate(imput_image) | |
pil_images = draw_bitmask(imput_image,masks) | |
# pil_images = draw_polygon(imput_image,masks) | |
# pil_images = draw_bitmask_split(imput_image,masks) | |
return masks,pil_images | |
# Generate a separate image for each binary mask.
def draw_bitmask_split(np_image,masks): | |
for i,obj in enumerate(masks): | |
print("segmentation:",obj["segmentation"].shape) | |
view = Visualizer(np_image) | |
view.draw_binary_mask(obj["segmentation"]) | |
vis_image = view.get_output() | |
pil_images = visimage_to_pil([vis_image],idx=i) | |
return pil_images | |
# Draw all binary masks onto a single image.
def draw_bitmask(np_image,masks): | |
view = Visualizer(np_image) | |
for obj in masks: | |
if "segmentation" in obj: | |
print("segmentation:",obj["segmentation"].shape) | |
view.draw_binary_mask(obj["segmentation"]) | |
else: | |
view.draw_binary_mask(obj) | |
vis_image = view.get_output() | |
pil_images = visimage_to_pil([vis_image]) | |
return pil_images | |
# Draw masks as polygon outlines.
def draw_polygon(np_image,masks): | |
view = Visualizer(np_image) | |
for obj in masks: | |
polygon = bitmask_to_polygon(obj["segmentation"]) | |
view.draw_polygon(polygon,"k") | |
vis_image = view.get_output() | |
pil_images = visimage_to_pil([vis_image]) | |
return pil_images | |
# Convert a binary mask into a polygon.
def bitmask_to_polygon(mask): | |
col_mask = np.asfortranarray(mask) | |
contours = measure.find_contours(col_mask,0.5) | |
print("contours------",contours.shape) | |
for i,contour in enumerate(contours): | |
contour = np.flip(contour, axis=1) | |
print(f"polygon_{i}",contour.shape) | |
# polygon = contour.ravel().tolist() | |
# print(f"polygon_{i}",polygon) | |
return contour | |
# Convert detectron2 VisImage outputs to PIL images.
def visimage_to_pil(visimages,need_save=False,idx=0): | |
pil_images = [] | |
for i,visimage in enumerate(visimages): | |
visualized_image = visimage.get_image()[:, :, ::-1] | |
pil_image = Image.fromarray(visualized_image) | |
if need_save: | |
pil_image.save(f"{idx}_{i}.jpg") | |
pil_images.append(pil_image) | |
return pil_images | |
def image_to_mask(image, threshold=128):
    """Binarize a PIL image into a black/white mask image.

    The image is converted to mode ``'L'`` (grayscale) unless it already is.
    Pixels strictly greater than ``threshold`` become 255 and all others 0.

    Args:
        image: PIL image of any mode.
        threshold: grayscale cutoff; only values ``> threshold`` are foreground.

    Returns:
        A new PIL image built from the 0/255 ``uint8`` array.
    """
    gray = image if image.mode == 'L' else image.convert('L')
    foreground = np.array(gray) > threshold
    return Image.fromarray(foreground.astype(np.uint8) * 255)
class SamAnything: | |
_instance = None | |
_lock = threading.Lock() | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
def __new__(cls, *args, **kwargs): | |
if cls._instance is None: | |
with cls._lock: | |
if cls._instance is None: | |
cls._instance = super(SamAnything, cls).__new__(cls) | |
cls._instance._initialize(*args, **kwargs) | |
return cls._instance | |
def _initialize(self, checkpoint_path="./sam_vit_b_01ec64.pth"): | |
self.sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path).to(self.device) | |
self.predictor = SamPredictor(self.sam) | |
self.mask_generator = SamAutomaticMaskGenerator(self.sam) | |
def seg_with_promp(self, imput_image, point_coords=None, box=None): | |
if isinstance(imput_image, Image.Image): | |
imput_image = pil_image_to_numpy(imput_image) | |
point_labels = None | |
if point_coords is not None: | |
point_labels = np.ones(point_coords.shape[0]) | |
self.predictor.set_image(imput_image) | |
masks = None | |
if box is not None: | |
masks, _, _ = self.predictor.predict(box=box) | |
elif point_coords is not None and point_labels is not None: | |
masks, _, _ = self.predictor.predict(point_coords=point_coords, point_labels=point_labels) | |
print("seg_with_promp:", masks.shape) | |
pil_images = self.draw_bitmask(imput_image, masks) | |
return masks, pil_images | |
def seg_all(self, imput_image): | |
if isinstance(imput_image, Image.Image): | |
imput_image = pil_image_to_numpy(imput_image) | |
masks = self.mask_generator.generate(imput_image) | |
pil_images = self.draw_bitmask(imput_image, masks) | |
return masks, pil_images | |
def draw_bitmask_split(np_image, masks): | |
pil_images = [] | |
for i, obj in enumerate(masks): | |
print("segmentation:", obj["segmentation"].shape) | |
view = Visualizer(np_image) | |
view.draw_binary_mask(obj["segmentation"]) | |
vis_image = view.get_output() | |
pil_images.extend(SamAnything.visimage_to_pil([vis_image], idx=i)) | |
return pil_images | |
def draw_bitmask(np_image, masks): | |
view = Visualizer(np_image) | |
for obj in masks: | |
if "segmentation" in obj: | |
print("segmentation:", obj["segmentation"].shape) | |
view.draw_binary_mask(obj["segmentation"]) | |
else: | |
view.draw_binary_mask(obj) | |
vis_image = view.get_output() | |
pil_images = SamAnything.visimage_to_pil([vis_image]) | |
return pil_images | |
def draw_polygon(np_image, masks): | |
view = Visualizer(np_image) | |
for obj in masks: | |
polygon = SamAnything.bitmask_to_polygon(obj["segmentation"]) | |
view.draw_polygon(polygon, "k") | |
vis_image = view.get_output() | |
pil_images = SamAnything.visimage_to_pil([vis_image]) | |
return pil_images | |
def bitmask_to_polygon(mask): | |
col_mask = np.asfortranarray(mask) | |
contours = measure.find_contours(col_mask, 0.5) | |
print("contours------", len(contours)) | |
for i, contour in enumerate(contours): | |
contour = np.flip(contour, axis=1) | |
print(f"polygon_{i}", contour.shape) | |
return contours | |
def visimage_to_pil(visimages, need_save=True, idx=0): | |
pil_images = [] | |
for i, visimage in enumerate(visimages): | |
visualized_image = visimage.get_image() | |
pil_image = Image.fromarray(visualized_image) | |
if need_save: | |
pil_image.save(f"{idx}_{i}.jpg") | |
pil_images.append(pil_image) | |
return pil_images | |
def image_to_mask(image, threshold=128): | |
if image.mode != 'L': | |
image = image.convert('L') | |
mask_array = np.array(image) > threshold | |
mask_image = Image.fromarray(np.uint8(mask_array) * 255) | |
return mask_image | |
# Example usage:
# if __name__ == "__main__":
#     np_image = read_image("./test/face1.jpeg")
#     print("np_image:", np_image.shape)
#     sam = SamAnything("./sam_vit_b_01ec64.pth")
#     masks, pil_images = sam.seg_all(np_image)