File size: 8,335 Bytes
68d34d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac510cd
68d34d0
 
ac510cd
68d34d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import threading
from collections.abc import Mapping

import numpy as np
import torch
from detectron2.data.detection_utils import read_image, pil_image_to_numpy
from detectron2.utils.visualizer import Visualizer
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
from skimage import measure

# Module-wide inference device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def seg_with_promp(imput_image, point_coords=None, box=None):
    """Segment an image with SAM using a box prompt or point prompts.

    A fresh ``vit_b`` SAM model is loaded from ``./sam_vit_b_01ec64.pth`` and
    moved to the module-level ``device`` on every call. Use ``SamAnything`` to
    reuse one loaded model across calls.

    Args:
        imput_image: PIL image or numpy image array. A PIL image is converted
            with ``pil_image_to_numpy`` first.
        point_coords: array-like of point prompts, one row per point. Every
            point is labelled foreground (label 1).
        box: array-like box prompt. When given it takes precedence and
            ``point_coords`` is ignored.

    Returns:
        ``(masks, pil_images)``: the masks returned by ``SamPredictor.predict``
        and the list of PIL images produced by ``draw_bitmask``.

    Raises:
        ValueError: if neither ``box`` nor ``point_coords`` is given.
    """
    # Validate before loading the expensive model. Without a prompt there is
    # nothing to predict, and the old code crashed later on ``masks.shape``.
    if box is None and point_coords is None:
        raise ValueError("seg_with_promp requires either point_coords or box")

    if isinstance(imput_image, Image.Image):
        imput_image = pil_image_to_numpy(imput_image)

    sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth").to(device)
    predictor = SamPredictor(sam)
    predictor.set_image(imput_image)

    if box is not None:
        # SamPredictor applies array methods to the prompt, so normalise
        # lists/tuples to an ndarray.
        masks, _, _ = predictor.predict(box=np.asarray(box))
    else:
        point_coords = np.asarray(point_coords)
        point_labels = np.ones(point_coords.shape[0])
        masks, _, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels)

    print("seg_with_promp:", masks.shape)
    pil_images = draw_bitmask(imput_image, masks)
    return masks, pil_images

def seg_all(imput_image):
    """Run SAM's automatic mask generator over the whole image.

    A fresh ``vit_b`` SAM model is loaded from ``./sam_vit_b_01ec64.pth`` on
    every call.

    Args:
        imput_image: PIL image or numpy image array. A PIL image is converted
            with ``pil_image_to_numpy`` first.

    Returns:
        ``(masks, pil_images)``: the list returned by
        ``SamAutomaticMaskGenerator.generate`` and the list of PIL images
        produced by ``draw_bitmask``.
    """
    if isinstance(imput_image, Image.Image):
        imput_image = pil_image_to_numpy(imput_image)

    # Move the model to ``device`` like seg_with_promp does. Previously the
    # automatic generator always ran on the CPU, even when a GPU was available.
    sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth").to(device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(imput_image)
    pil_images = draw_bitmask(imput_image, masks)
    return masks, pil_images

def draw_bitmask_split(np_image, masks):
    """Render each binary mask on its own copy of the image.

    Args:
        np_image: image array handed to detectron2's ``Visualizer``.
        masks: iterable of mask records holding a binary mask under
            ``"segmentation"``.

    Returns:
        List of PIL images, one per mask, in the order of ``masks``. The list
        is empty when ``masks`` is empty.
    """
    pil_images = []
    for i, obj in enumerate(masks):
        print("segmentation:", obj["segmentation"].shape)
        # A fresh Visualizer per mask, so each output shows only that mask.
        view = Visualizer(np_image)
        view.draw_binary_mask(obj["segmentation"])
        vis_image = view.get_output()
        # Accumulate instead of overwriting. The old code returned only the
        # last mask's image and raised UnboundLocalError for an empty ``masks``.
        pil_images.extend(visimage_to_pil([vis_image], idx=i))
    return pil_images

def draw_bitmask(np_image, masks):
    """Overlay every binary mask onto ``np_image`` in one rendered image.

    Args:
        np_image: image array handed to detectron2's ``Visualizer``.
        masks: iterable whose items are either mask records (mappings holding
            the binary mask under ``"segmentation"``) or bare binary mask
            arrays.

    Returns:
        A one-element list containing the rendered PIL image.
    """
    view = Visualizer(np_image)
    for obj in masks:
        # Dispatch on type rather than ``"segmentation" in obj``. On a numpy
        # mask that membership test compares the whole array element-wise
        # against the string instead of doing a key lookup.
        if isinstance(obj, Mapping):
            mask = obj["segmentation"]
            print("segmentation:", mask.shape)
        else:
            mask = obj
        view.draw_binary_mask(mask)

    vis_image = view.get_output()
    return visimage_to_pil([vis_image])

def draw_polygon(np_image, masks):
    """Outline each mask as a black polygon on ``np_image``.

    Args:
        np_image: image array handed to detectron2's ``Visualizer``.
        masks: iterable of mask records (mappings with a ``"segmentation"``
            binary mask) or bare binary mask arrays.

    Returns:
        A one-element list containing the rendered PIL image.
    """
    view = Visualizer(np_image)
    for obj in masks:
        mask = obj["segmentation"] if isinstance(obj, Mapping) else obj
        polygon = bitmask_to_polygon(mask)
        # A polygon needs at least three vertices. Skip masks that produced no
        # usable contour instead of handing a degenerate shape to the Visualizer.
        if polygon is None or len(polygon) < 3:
            continue
        view.draw_polygon(polygon, "k")
    vis_image = view.get_output()
    return visimage_to_pil([vis_image])


def bitmask_to_polygon(mask):
    """Convert a binary mask into a single polygon outline.

    Contours are traced with ``skimage.measure.find_contours`` at level 0.5.
    The returned contour is flipped from (row, col) to (x, y) order.

    Args:
        mask: 2-D binary mask.

    Returns:
        The last contour reported by ``find_contours`` as an ``(N, 2)`` array of
        (x, y) vertices, or an empty ``(0, 2)`` array when the mask has no
        contour.
    """
    col_mask = np.asfortranarray(mask)
    contours = measure.find_contours(col_mask, 0.5)
    # find_contours returns a list of arrays. The old ``contours.shape`` raised
    # AttributeError on every call.
    print("contours------", len(contours))
    if not contours:
        # The old code hit UnboundLocalError here. Return an empty polygon instead.
        return np.empty((0, 2))
    for i, contour in enumerate(contours):
        print(f"polygon_{i}", contour.shape)
    return np.flip(contours[-1], axis=1)

def visimage_to_pil(visimages, need_save=False, idx=0):
    """Convert detectron2 ``VisImage`` objects into PIL images.

    Args:
        visimages: iterable of objects exposing ``get_image()``.
        need_save: when True, also write each image to ``"{idx}_{i}.jpg"`` in
            the current working directory.
        idx: prefix used in the saved file names.

    Returns:
        List of PIL images, one per input, in input order.
    """
    pil_images = []
    for i, visimage in enumerate(visimages):
        # VisImage.get_image() already returns RGB, which is what PIL expects.
        # The former ``[:, :, ::-1]`` reordered it to BGR and swapped the red
        # and blue channels of every output.
        pil_image = Image.fromarray(visimage.get_image())
        if need_save:
            pil_image.save(f"{idx}_{i}.jpg")
        pil_images.append(pil_image)
    return pil_images

def image_to_mask(image, threshold=128):
    """Binarize a PIL image into a black/white mask image.

    The image is first reduced to grayscale (``'L'`` mode) unless it already is.
    Pixels strictly brighter than ``threshold`` become 255; all others become 0.
    """
    gray = image if image.mode == 'L' else image.convert('L')
    bright = np.array(gray) > threshold
    return Image.fromarray(bright.astype(np.uint8) * 255)
class SamAnything:
    """Process-wide singleton wrapping one SAM ``vit_b`` model.

    The first successful instantiation loads the checkpoint onto ``device`` and
    builds a ``SamPredictor`` and a ``SamAutomaticMaskGenerator``. Every later
    ``SamAnything(...)`` call returns that same instance and ignores its
    arguments.
    """

    _instance = None
    # Guards creation of the singleton.
    _lock = threading.Lock()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __new__(cls, *args, **kwargs):
        # Fast path without the lock once the singleton exists; re-check under it.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super(SamAnything, cls).__new__(cls)
                    # Initialize before publishing. Threads taking the fast path
                    # never see an instance without a model, and if loading fails
                    # the next call retries instead of reusing a broken object.
                    instance._initialize(*args, **kwargs)
                    cls._instance = instance
        return cls._instance

    def _initialize(self, checkpoint_path="./sam_vit_b_01ec64.pth"):
        """Load the SAM checkpoint onto ``device`` and build the inference helpers."""
        self.sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path).to(self.device)
        self.predictor = SamPredictor(self.sam)
        self.mask_generator = SamAutomaticMaskGenerator(self.sam)
        # The predictor keeps the image passed to set_image() for the next
        # predict(), and both helpers share one model. Serialize inference so
        # concurrent callers of the shared singleton cannot interleave.
        self._infer_lock = threading.Lock()

    def seg_with_promp(self, imput_image, point_coords=None, box=None):
        """Segment an image with a box prompt or point prompts.

        Args:
            imput_image: PIL image or numpy image array. A PIL image is
                converted with ``pil_image_to_numpy`` first.
            point_coords: array-like of point prompts, one row per point. Every
                point is labelled foreground (label 1).
            box: array-like box prompt. When given it takes precedence and
                ``point_coords`` is ignored.

        Returns:
            ``(masks, pil_images)``: the masks returned by
            ``SamPredictor.predict`` and the images from ``draw_bitmask``.

        Raises:
            ValueError: if neither ``box`` nor ``point_coords`` is given.
        """
        # Without a prompt there is nothing to predict. The old code crashed
        # later on ``masks.shape``.
        if box is None and point_coords is None:
            raise ValueError("seg_with_promp requires either point_coords or box")

        if isinstance(imput_image, Image.Image):
            imput_image = pil_image_to_numpy(imput_image)

        with self._infer_lock:
            self.predictor.set_image(imput_image)
            if box is not None:
                # SamPredictor applies array methods to prompts; accept lists/tuples too.
                masks, _, _ = self.predictor.predict(box=np.asarray(box))
            else:
                point_coords = np.asarray(point_coords)
                point_labels = np.ones(point_coords.shape[0])
                masks, _, _ = self.predictor.predict(point_coords=point_coords, point_labels=point_labels)

        print("seg_with_promp:", masks.shape)
        pil_images = self.draw_bitmask(imput_image, masks)
        return masks, pil_images

    def seg_all(self, imput_image):
        """Run the automatic mask generator over the whole image.

        Returns:
            ``(masks, pil_images)``: the list returned by
            ``SamAutomaticMaskGenerator.generate`` and the images from
            ``draw_bitmask``.
        """
        if isinstance(imput_image, Image.Image):
            imput_image = pil_image_to_numpy(imput_image)

        with self._infer_lock:
            masks = self.mask_generator.generate(imput_image)
        pil_images = self.draw_bitmask(imput_image, masks)
        return masks, pil_images

    @staticmethod
    def draw_bitmask_split(np_image, masks):
        """Render each mask record's ``"segmentation"`` mask on its own image copy.

        Returns:
            List of PIL images, one per mask, in the order of ``masks``.
        """
        pil_images = []
        for i, obj in enumerate(masks):
            print("segmentation:", obj["segmentation"].shape)
            view = Visualizer(np_image)
            view.draw_binary_mask(obj["segmentation"])
            vis_image = view.get_output()
            pil_images.extend(SamAnything.visimage_to_pil([vis_image], idx=i))
        return pil_images

    @staticmethod
    def draw_bitmask(np_image, masks):
        """Overlay every mask onto ``np_image`` in one rendered image.

        ``masks`` items may be mask records (mappings with a ``"segmentation"``
        binary mask) or bare binary mask arrays.

        Returns:
            A one-element list containing the rendered PIL image.
        """
        view = Visualizer(np_image)
        for obj in masks:
            # Type check instead of ``"segmentation" in obj``, which on a numpy
            # mask compares the whole array element-wise against the string.
            if isinstance(obj, Mapping):
                mask = obj["segmentation"]
                print("segmentation:", mask.shape)
            else:
                mask = obj
            view.draw_binary_mask(mask)

        vis_image = view.get_output()
        return SamAnything.visimage_to_pil([vis_image])

    @staticmethod
    def draw_polygon(np_image, masks):
        """Outline every contour of every mask as a black polygon.

        ``masks`` items may be mask records (mappings with a ``"segmentation"``
        binary mask) or bare binary mask arrays.

        Returns:
            A one-element list containing the rendered PIL image.
        """
        view = Visualizer(np_image)
        for obj in masks:
            mask = obj["segmentation"] if isinstance(obj, Mapping) else obj
            # bitmask_to_polygon returns a list of contours, but
            # Visualizer.draw_polygon takes one (N, 2) polygon. Draw each
            # contour separately and skip degenerate ones.
            for polygon in SamAnything.bitmask_to_polygon(mask):
                if len(polygon) >= 3:
                    view.draw_polygon(polygon, "k")
        vis_image = view.get_output()
        return SamAnything.visimage_to_pil([vis_image])

    @staticmethod
    def bitmask_to_polygon(mask):
        """Trace all contours of a binary mask as (x, y) polygons.

        Contours come from ``skimage.measure.find_contours`` at level 0.5 and
        are flipped from (row, col) to (x, y) order.

        Returns:
            List of ``(N, 2)`` arrays, one per contour. The list is empty when
            the mask has no contour.
        """
        col_mask = np.asfortranarray(mask)
        contours = measure.find_contours(col_mask, 0.5)
        print("contours------", len(contours))
        polygons = []
        for i, contour in enumerate(contours):
            # The old code computed this flip but discarded it and returned the
            # raw (row, col) contours.
            polygon = np.flip(contour, axis=1)
            print(f"polygon_{i}", polygon.shape)
            polygons.append(polygon)
        return polygons

    @staticmethod
    def visimage_to_pil(visimages, need_save=True, idx=0):
        """Convert objects exposing ``get_image()`` into PIL images.

        When ``need_save`` is True, each image is also written to
        ``"{idx}_{i}.jpg"`` in the current working directory.
        NOTE(review): the default is True, so draw_bitmask/draw_polygon write
        ``0_0.jpg`` on every call. Confirm this is intended.
        """
        pil_images = []
        for i, visimage in enumerate(visimages):
            visualized_image = visimage.get_image()
            pil_image = Image.fromarray(visualized_image)
            if need_save:
                pil_image.save(f"{idx}_{i}.jpg")
            pil_images.append(pil_image)
        return pil_images

    @staticmethod
    def image_to_mask(image, threshold=128):
        """Binarize a PIL image: grayscale pixels above ``threshold`` become 255, others 0."""
        if image.mode != 'L':
            image = image.convert('L')
        mask_array = np.array(image) > threshold
        mask_image = Image.fromarray(np.uint8(mask_array) * 255)
        return mask_image

# if __name__ == "__main__":
#     np_image = read_image("./test/face1.jpeg")
#     print("np_image:",np_image.shape)
#     sam = SamAnything("./sam_vit_b_01ec64.pth")
#     sam.seg_all(np_image)