import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from third_parts.segment_anything import sam_model_registry
from third_parts.segment_anything.utils.transforms import ResizeLongestSide


def mask2box(mask):
    """Convert a binary mask (numpy array) to an xyxy box with exclusive right/bottom edges."""
    ys, xs = np.where(mask > 0)
    y0, y1 = ys.min(), ys.max()
    x0, x1 = xs.min(), xs.max()
    # +1 keeps the box non-degenerate when the mask spans a single row/column (avoids x0 == x1).
    return np.array([x0, y0, x1 + 1, y1 + 1])


def compute_mask_IoU(masks, target):
    """Compute intersection, union and IoU between flattened binary masks and a target mask."""
    temp = masks * target
    intersection = temp.sum(dim=-1)
    union = ((masks + target) - temp).sum(dim=-1)
    return intersection, union, intersection / (union + 1e-12)


class SAMWrapper(nn.Module):
    """Wraps SAM with a frozen image encoder and refines predicted masks by prompting the
    mask decoder with mask, box, and text-embedding prompts."""

    def __init__(self, model_name, checkpoint,
                 use_text=True, use_mask=True, use_box=True,
                 multimask_output=False):
        super().__init__()
        self.model = sam_model_registry[model_name](checkpoint=checkpoint)
        # The image encoder stays frozen; only the prompt encoder and mask decoder are trainable.
        self.model.image_encoder.requires_grad_(False)
        self.transform = ResizeLongestSide(self.model.image_encoder.img_size)
        self.use_text = use_text
        self.use_mask = use_mask
        self.use_box = use_box
        self.multimask_output = multimask_output

    def train(self, mode=True):
        # Keep the frozen image encoder in eval mode even when the wrapper is switched to training.
        super().train(mode=mode)
        self.model.image_encoder.eval()
        self.training = mode
        return self

    @property
    def dtype(self):
        return self.model.dtype

    @torch.no_grad()
    def encode_image(self, image):
        # Convert the PIL image to SAM's expected format and resize its longest side to img_size.
        image = np.array(image.convert(self.model.image_format))
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.model.device)
        transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        original_image_size = image.shape[:2]
        input_size = transformed_image.shape[-2:]

        # Run the frozen image encoder on the padded, normalized image.
        features = self.model.image_encoder(self.model.preprocess(transformed_image))

        return features, original_image_size, input_size

    def generate_prompt_masks(self, masks, input_size):
        # Resize mask logits to the transformed input size, pad to the encoder's square canvas
        # with a strongly negative value, then downsample to SAM's 256x256 mask-prompt size.
        pad_value = min(-1.0, masks.min().item())
        masks = F.interpolate(masks[:, None].float(),
                              size=input_size, mode='bilinear').to(masks)
        h, w = masks.shape[-2:]
        masks = F.pad(masks, (0, self.model.image_encoder.img_size - w,
                              0, self.model.image_encoder.img_size - h), value=pad_value)
        prompt_masks = F.interpolate(masks.float(),
                                     size=(256, 256), mode='bilinear').to(masks)
        return prompt_masks

    def forward(self, image, pred_masks, text_embeds):
        # pred_masks are mask logits from the upstream predictor; one prompt set is built per mask.
        image_embedding, original_image_size, input_size = self.encode_image(image)
        if self.training:
            # encode_image runs under no_grad; re-enable gradients on the returned embedding for training.
            image_embedding.requires_grad = True
        prompt_masks = self.generate_prompt_masks(pred_masks, input_size)

        # Binarize the predictions at the original image resolution (used for box prompts and IoU selection).
        pred_masks = F.interpolate(pred_masks.detach()[None].float().sigmoid(),
                                   size=original_image_size, mode='bilinear')[0]
        pred_masks = (pred_masks > 0.5).to(pred_masks)

        sam_masks = []
        for prompt_mask, pred_mask, text_embed in zip(prompt_masks, pred_masks, text_embeds):
            if self.use_box:
                # Build a box prompt from the binarized mask; fall back to the full image if the mask is empty.
                if pred_mask.sum() > 0:
                    box = mask2box(pred_mask.float().cpu().numpy())
                else:
                    h, w = original_image_size
                    box = np.array([0.0, 0.0, w, h])
                box = self.transform.apply_boxes(box, original_image_size)
                box_torch = torch.as_tensor(
                    box, dtype=pred_mask.dtype, device=self.model.device)
                box_torch = box_torch[None, :]    # shape (1, 1, 4)
            else:
                box_torch = None
            sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=prompt_mask.view(1, 1, 256, 256) if self.use_mask else None,
            )
            if self.use_text:
                # Append the text embedding as additional sparse prompt tokens.
                sparse_embeddings = torch.cat([sparse_embeddings.to(dense_embeddings),
                                               text_embed[None].to(dense_embeddings)], dim=1)
            else:
                sparse_embeddings = sparse_embeddings.to(dense_embeddings)
            low_res_masks, iou_predictions = self.model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=self.multimask_output,
            )
            # Upsample the low-resolution decoder output back to the original image size.
            sam_mask = self.model.postprocess_masks(
                low_res_masks, input_size, original_image_size)

            if self.multimask_output:
                # Of the candidate masks, keep the one with the highest IoU against the input prediction.
                candidate_masks = (sam_mask[0] > 0.0).float()
                candidate_ious = compute_mask_IoU(candidate_masks.view(3, -1),
                                                  pred_mask.float().view(1, -1))[-1]
                sam_mask = sam_mask[0, candidate_ious.argmax()]
            else:
                assert sam_mask.shape[1] == 1
                sam_mask = sam_mask[0, 0]
            sam_masks.append(sam_mask)

        return torch.stack(sam_masks)

    def state_dict(self, *args, **kwargs):
        # Drop the frozen image encoder from saved checkpoints to keep them small.
        state_dict = super().state_dict(*args, **kwargs)
        return {k: v for k, v in state_dict.items() if 'image_encoder' not in k}
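

if __name__ == '__main__':
    # Minimal usage sketch, not part of the training pipeline: it only illustrates the input
    # shapes SAMWrapper.forward expects. The checkpoint path, image size, number of objects,
    # and the 256-dim prompt-token size below are illustrative assumptions, not values
    # mandated by this module.
    from PIL import Image

    wrapper = SAMWrapper('vit_h', checkpoint='checkpoints/sam_vit_h_4b8939.pth',
                         use_text=False, use_mask=True, use_box=True).eval()

    image = Image.new('RGB', (640, 480))     # placeholder PIL image
    pred_masks = torch.randn(2, 480, 640)    # per-object mask logits, shape (N, H, W)
    text_embeds = torch.randn(2, 1, 256)     # per-object prompt tokens, shape (N, k, 256); ignored when use_text=False

    with torch.no_grad():
        sam_masks = wrapper(image, pred_masks, text_embeds)
    print(sam_masks.shape)    # refined mask logits at the original image size, (N, 480, 640)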