File size: 5,621 Bytes
c426a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import time
import torch
import cv2
from PIL import Image, ImageDraw, ImageOps
import numpy as np
from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL

class BaseSegmenter:
    """Wrapper around Segment Anything (SAM) for prompt-driven segmentation.

    The instance owns a SAM model, a ``SamPredictor`` for click/box prompts
    and a ``SamAutomaticMaskGenerator`` for the "everything" mode. When
    ``reuse_feature`` is true the image embedding computed by ``set_image``
    is cached and reused by every following ``inference`` call.
    """

    def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True):
        """Load the SAM checkpoint and build the predictor and mask generator.

        Args:
            device: Target device, either a string such as ``'cuda'`` /
                ``'cpu'`` or a ``torch.device``.
            checkpoint: Path of the SAM checkpoint file.
            model_type: Key into ``sam_model_registry`` (default ``'vit_h'``).
            reuse_feature: If true, ``set_image`` caches the image embedding
                so that ``inference`` does not recompute it.
        """
        print(f"Initializing BaseSegmenter to {device}")
        self.device = device
        # str() so that a torch.device instance works as well as a plain string.
        self.torch_dtype = torch.float16 if 'cuda' in str(device) else torch.float32
        self.processor = None
        self.model_type = model_type
        self.checkpoint = checkpoint
        self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
        self.model.to(device=self.device)
        self.reuse_feature = reuse_feature
        self.predictor = SamPredictor(self.model)
        self.mask_generator = SamAutomaticMaskGenerator(self.model)
        self.image_embedding = None
        self.image = None

    @staticmethod
    def _to_array(image):
        """Return *image* as a numpy array when it is a path or a PIL image.

        Any other value (typically an ``np.ndarray``) is returned unchanged.
        """
        if isinstance(image, np.ndarray):
            return image
        if isinstance(image, str):  # input path
            # The context manager closes the file once the pixels are copied.
            with Image.open(image) as opened:
                return np.array(opened)
        # isinstance (not an exact type match) so that the format-specific
        # subclasses returned by Image.open are converted as well.
        if isinstance(image, Image.Image):
            return np.array(image)
        return image

    @torch.no_grad()
    def set_image(self, image: "Union[np.ndarray, Image.Image, str]"):
        """Store *image* and, if features are reused, cache its embedding.

        Args:
            image: A numpy array, a PIL image, or a path to an image file.
        """
        image = self._to_array(image)
        self.image = image
        if self.reuse_feature:
            self.predictor.set_image(image)
            self.image_embedding = self.predictor.get_image_embedding()
            print(self.image_embedding.shape)

    @torch.no_grad()
    def inference(self, image, control):
        """Segment according to the prompt described by *control*.

        Args:
            image: The image to segment. It is only used directly in the
                "everything" mode and when ``reuse_feature`` is false;
                otherwise the embedding cached by ``set_image`` is used.
            control: Prompt dictionary. Keys read here:
                ``prompt_type`` (container of ``'click'``, ``'box'`` and/or
                ``'everything'``), ``input_point``, ``input_label``,
                ``input_box``, ``input_boxes`` (a batch of boxes) and
                ``multimask_output``. The mere presence of
                ``multimask_output`` requests several candidate masks; the
                historical misspelling ``mutimask_output`` is still accepted.

        Returns:
            The predicted masks as a numpy array whose first axis indexes
            the masks.
        """
        prompt_type = control['prompt_type']

        if 'everything' in prompt_type:
            image = self._to_array(image)
            masks = self.mask_generator.generate(image)
            if not masks:
                # Nothing was segmented: return an empty stack instead of
                # letting numpy fail on an empty list.
                return np.zeros((0, *image.shape[:2]), dtype=bool)
            return np.stack([mask["segmentation"] for mask in masks])

        if not self.reuse_feature:
            self.set_image(image)
            self.predictor.set_image(self.image)
        else:
            assert self.image_embedding is not None, \
                "set_image() must be called before inference()"
            self.predictor.features = self.image_embedding

        if 'multimask_output' in control or 'mutimask_output' in control:
            masks, scores, logits = self.predictor.predict(
                point_coords = np.array(control['input_point']),
                point_labels = np.array(control['input_label']),
                multimask_output = True,
            )
        elif 'input_boxes' in control:
            # Boxes must be rescaled relative to the image whose embedding is
            # in use, i.e. the one stored by set_image.
            reference = self.image if self.image is not None else image
            transformed_boxes = self.predictor.transform.apply_boxes_torch(
                torch.tensor(control["input_boxes"], device=self.predictor.device),
                reference.shape[:2],
            )
            masks, _, _ = self.predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            masks = masks.squeeze(1).cpu().numpy()
        else:
            use_click = 'click' in prompt_type
            input_point = np.array(control['input_point']) if use_click else None
            input_label = np.array(control['input_label']) if use_click else None
            input_box = np.array(control['input_box']) if 'box' in prompt_type else None

            masks, scores, logits = self.predictor.predict(
                point_coords = input_point,
                point_labels = input_label,
                box = input_box,
                multimask_output = False,
            )

            # A box-only prompt carries no labels, hence .get(). When a label
            # 0 (background point in SAM's convention) is present, run a
            # second pass seeded with the best-scoring logits.
            if 0 in control.get('input_label', ()):
                mask_input = logits[np.argmax(scores), :, :]
                masks, scores, logits = self.predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    box = input_box,
                    mask_input=mask_input[None, :, :],
                    multimask_output=False,
                )

        return masks

if __name__ == "__main__":
    # Demo: run every enabled prompt below against one image and save the
    # first mask of each result to 'seg.png'. Uncomment an entry to try
    # another prompt mode.
    prompts = [
        # {
        #     "prompt_type":["click"],
        #     "input_point":[[500, 375]],
        #     "input_label":[1],
        #     "multimask_output":"True",
        # },
        {
            "prompt_type":["click"],
            "input_point":[[1000, 600], [1325, 625]],
            "input_label":[1, 0],
        },
        # {
        #     "prompt_type":["click", "box"],
        #     "input_box":[425, 600, 700, 875],
        #     "input_point":[[575, 750]],
        #     "input_label": [0]
        # },
        # {
        #     "prompt_type":["box"],
        #     "input_boxes": [
        #         [75, 275, 1725, 850],
        #         [425, 600, 700, 875],
        #         [1375, 550, 1650, 800],
        #         [1240, 675, 1400, 750],
        #     ]
        # },
        # {
        #     "prompt_type":["everything"]
        # },
    ]

    # Fall back to the CPU so the demo still runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    init_time = time.time()
    segmenter = BaseSegmenter(
        device=device,
        # checkpoint='sam_vit_h_4b8939.pth',
        checkpoint='segmenter/sam_vit_h_4b8939.pth',
        model_type='vit_h',
        reuse_feature=True
    )
    print(f'init time: {time.time() - init_time}')

    # image_path = 'segmenter/images/truck.jpg'
    image_path = 'test_img/img2.jpg'
    # Decode the image once; the context manager closes the file afterwards.
    with Image.open(image_path) as opened:
        image = np.array(opened)

    infer_time = time.time()
    # The embedding is computed a single time and reused for every prompt
    # (the segmenter was built with reuse_feature=True).
    segmenter.set_image(image)
    for prompt in prompts:
        print(f'{prompt["prompt_type"]} mode')
        masks = segmenter.inference(image, prompt)
        Image.fromarray(masks[0]).save('seg.png')
        print(masks.shape)

    print(f'infer time: {time.time() - infer_time}')