File size: 4,723 Bytes
8fc40b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from .model import FastSAM
import numpy as np
from PIL import Image
import clip
from typing import Optional, List, Tuple, Union


class FastSAMDecoder:
    """Select a single segmentation mask from a FastSAM result using a prompt.

    Usage is two-stage: :meth:`run_encoder` runs the wrapped model on an
    image and remembers that image; :meth:`run_decoder` then picks a mask
    from the model output according to a point, box or text prompt.

    The model output ("image embedding") is accessed as an object exposing
    ``masks.data`` (indexable as ``[n, h, w]``), ``boxes.data`` and
    ``boxes.conf``.
    NOTE(review): this looks like an ultralytics-style ``Results`` object --
    confirm against ``FastSAM``.
    """

    def __init__(
        self,
        model: "FastSAM",
        device: str = 'cpu',
        conf: float = 0.4,
        iou: float = 0.9,
        imgsz: int = 1024,
        retina_masks: bool = True,
    ):
        """Store the model and the inference options forwarded to it.

        Args:
            model: Callable FastSAM model; invoked in :meth:`run_encoder`.
            device: Forwarded to the model as ``device``.
            conf: Forwarded to the model as ``conf``.
            iou: Forwarded to the model as ``iou``.
            imgsz: Forwarded to the model as ``imgsz``.
            retina_masks: Forwarded to the model as ``retina_masks``.
        """
        self.model = model
        self.device = device
        self.retina_masks = retina_masks
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou
        # Populated by run_encoder / run_decoder respectively.
        self.image = None
        self.image_embedding = None

    def run_encoder(self, image):
        """Run the model on ``image`` and return its first result as NumPy.

        Args:
            image: Either a path (``str``) that is opened with PIL, or an
                image object passed to the model unchanged.

        Returns:
            ``model(...)[0].numpy()`` -- the first result converted by its
            own ``numpy()`` method.
        """
        if isinstance(image, str):
            # Context manager so the underlying file handle is released.
            with Image.open(image) as img:
                image = np.array(img)
        self.image = image
        image_embedding = self.model(
            self.image,
            device=self.device,
            retina_masks=self.retina_masks,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
        )
        return image_embedding[0].numpy()

    def run_decoder(
        self,
        image_embedding,
        point_prompt: Optional[np.ndarray] = None,
        point_label: Optional[np.ndarray] = None,
        box_prompt: Optional[np.ndarray] = None,
        text_prompt: Optional[str] = None,
    ) -> Optional[np.ndarray]:
        """Pick a mask from ``image_embedding`` using the first prompt given.

        Prompts are checked in the order point, box, text; only the first
        non-``None`` one is used.

        Args:
            image_embedding: Model output as returned by :meth:`run_encoder`
                (a list/tuple of results is also accepted; its first element
                is used).
            point_prompt: Sequence of ``(x, y)`` points in image pixels.
            point_label: One label per point, ``1`` to include the mask under
                the point and ``0`` to exclude it. ``None`` treats every
                point as ``1``.
            box_prompt: ``(x1, y1, x2, y2)`` box in image pixels.
            text_prompt: Text query; only usable when the instance provides
                a ``text_prompt`` method.

        Returns:
            Array of shape ``(1, h, w)`` holding the selected mask, or
            ``None`` when no prompt was supplied.

        Raises:
            NotImplementedError: If ``text_prompt`` is given but this object
                has no ``text_prompt`` method.
        """
        self.image_embedding = image_embedding
        if point_prompt is not None:
            return self.point_prompt(points=point_prompt, pointlabel=point_label)
        if box_prompt is not None:
            return self.box_prompt(bbox=box_prompt)
        if text_prompt is not None:
            # No text_prompt method is defined on this class; look it up
            # dynamically so subclasses that add one keep working and
            # everyone else gets an explicit error.
            handler = getattr(self, 'text_prompt', None)
            if handler is None:
                raise NotImplementedError(
                    'text prompts are not supported by this decoder'
                )
            return handler(text=text_prompt)
        return None

    def box_prompt(self, bbox):
        """Return the mask with the highest IoU against ``bbox``.

        Args:
            bbox: ``(x1, y1, x2, y2)`` in pixels of the encoded image. It is
                rescaled to mask resolution when the two differ. The input
                sequence is not modified.

        Returns:
            Array of shape ``(1, h, w)`` containing the best-matching mask.

        Raises:
            AssertionError: If ``x2`` or ``y2`` is zero.
            ValueError: If the result contains no masks.
        """
        assert bbox[2] != 0 and bbox[3] != 0, 'bbox x2 and y2 must be non-zero'
        masks = self._to_numpy(self._current_result().masks.data)
        if len(masks) == 0:
            raise ValueError('no masks available to match against bbox')

        target_height, target_width = self.image.shape[0], self.image.shape[1]
        h, w = masks.shape[1], masks.shape[2]

        # Work on local copies so the caller's bbox is left untouched and
        # immutable sequences (tuples) are accepted.
        x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
        if h != target_height or w != target_width:
            x1 = int(x1 * w / target_width)
            y1 = int(y1 * h / target_height)
            x2 = int(x2 * w / target_width)
            y2 = int(y2 * h / target_height)

        # Clamp the top-left corner to >= 0 and the bottom-right to the mask
        # extent; int() guarantees valid slice indices for NumPy scalars.
        x1 = max(int(round(x1)), 0)
        y1 = max(int(round(y1)), 0)
        x2 = min(int(round(x2)), w)
        y2 = min(int(round(y2)), h)

        bbox_area = (y2 - y1) * (x2 - x1)
        inside_area = masks[:, y1:y2, x1:x2].sum(axis=(1, 2))
        full_area = masks.sum(axis=(1, 2))
        union = bbox_area + full_area - inside_area

        # A zero union (degenerate box against an empty mask) scores 0
        # instead of producing NaN.
        ious = np.divide(
            inside_area,
            union,
            out=np.zeros(len(masks), dtype=np.float64),
            where=union > 0,
        )
        best = int(np.argmax(ious))
        return np.array([masks[best]])

    def point_prompt(self, points, pointlabel):
        """Build a mask from the segments lying under the given points.

        Segments are visited from largest to smallest area. For each
        segment and each point inside it, label ``1`` switches the segment
        on in the output and label ``0`` switches it off; later writes win.

        Args:
            points: Sequence of ``(x, y)`` points in pixels of the encoded
                image; rescaled to mask resolution when the two differ.
            pointlabel: One label per point (``1`` include, ``0`` exclude).
                ``None`` treats every point as ``1``.

        Returns:
            Boolean array of shape ``(1, h, w)``. When the result holds no
            masks, an all-``False`` mask at image resolution is returned.
        """
        annotations = self._format_results(self._current_result(), 0)
        target_height, target_width = self.image.shape[0], self.image.shape[1]
        if not annotations:
            empty = np.zeros((target_height, target_width), dtype=bool)
            return np.array([empty])
        if pointlabel is None:
            pointlabel = [1] * len(points)

        h = annotations[0]['segmentation'].shape[0]
        w = annotations[0]['segmentation'].shape[1]
        if h != target_height or w != target_width:
            points = [
                [int(pt[0] * w / target_width), int(pt[1] * h / target_height)]
                for pt in points
            ]

        onemask = np.zeros((h, w))
        annotations = sorted(annotations, key=lambda a: a['area'], reverse=True)
        for annotation in annotations:
            if isinstance(annotation, dict):
                mask = annotation['segmentation']
            else:
                mask = annotation
            for idx, pt in enumerate(points):
                # Points are (x, y); masks are indexed [row, col].
                if mask[pt[1], pt[0]] != 1:
                    continue
                if pointlabel[idx] == 1:
                    onemask[mask] = 1
                elif pointlabel[idx] == 0:
                    onemask[mask] = 0
        return np.array([onemask >= 1])

    def _current_result(self):
        """Return the stored result, unwrapping a list/tuple to its first item."""
        embedding = self.image_embedding
        if isinstance(embedding, (list, tuple)):
            return embedding[0]
        return embedding

    @staticmethod
    def _to_numpy(data):
        """Return ``data`` as a NumPy array.

        Arrays pass through unchanged. Objects exposing ``detach``/``cpu``/
        ``numpy`` methods (tensor-like) are converted through them;
        anything else goes through ``np.asarray``.
        """
        if isinstance(data, np.ndarray):
            return data
        for attr in ('detach', 'cpu'):
            method = getattr(data, attr, None)
            if callable(method):
                data = method()
        to_numpy = getattr(data, 'numpy', None)
        if callable(to_numpy):
            return to_numpy()
        return np.asarray(data)

    def _format_results(self, result, filter=0):
        """Convert a result into a list of per-mask annotation dicts.

        Args:
            result: Object exposing ``masks.data``, ``boxes.data`` and
                ``boxes.conf``, all indexable by detection index.
            filter: Masks with fewer than this many pixels equal to ``1.0``
                are dropped.

        Returns:
            List of dicts with keys ``id`` (detection index),
            ``segmentation`` (boolean mask), ``bbox``, ``score`` and
            ``area`` (number of ``True`` pixels).
        """
        mask_data = self._to_numpy(result.masks.data)
        annotations = []
        for idx, raw_mask in enumerate(mask_data):
            mask = raw_mask == 1.0
            area = mask.sum()
            if area < filter:
                continue
            annotations.append({
                'id': idx,
                'segmentation': mask,
                'bbox': result.boxes.data[idx],
                'score': result.boxes.conf[idx],
                'area': area,
            })
        return annotations