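"""Thin wrapper around Segment Anything (SAM) for refining coarse mask predictions.

Given an image, per-object mask logits, and optional text embeddings, SAMWrapper
re-prompts a SAM model (with a frozen image encoder) using mask, box, and text
prompts and returns the refined masks.
"""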
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from third_parts.segment_anything import sam_model_registry
from third_parts.segment_anything.utils.transforms import ResizeLongestSide


def mask2box(mask):
    """Convert a binary mask (H, W) into an xyxy box enclosing its foreground pixels."""
    ys, xs = np.where(mask > 0)
    y0, y1 = ys.min(), ys.max()
    x0, x1 = xs.min(), xs.max()
    return np.array([x0, y0, x1 + 1, y1 + 1])  # +1 avoids a degenerate box when x0 == x1


def compute_mask_IoU(masks, target):
    """Return (intersection, union, IoU) between flattened binary masks and a target mask."""
    temp = masks * target
    intersection = temp.sum(dim=-1)
    union = ((masks + target) - temp).sum(dim=-1)
    return intersection, union, intersection / (union + 1e-12)


class SAMWrapper(nn.Module):

    def __init__(self, model_name, checkpoint,
                 use_text=True, use_mask=True, use_box=True,
                 multimask_output=False):
        super(SAMWrapper, self).__init__()
        self.model = sam_model_registry[model_name](checkpoint=checkpoint)
        # The SAM image encoder stays frozen; only the prompt encoder / mask decoder are tuned.
        self.model.image_encoder.requires_grad_(False)
        self.transform = ResizeLongestSide(self.model.image_encoder.img_size)

        self.use_text = use_text
        self.use_mask = use_mask
        self.use_box = use_box
        self.multimask_output = multimask_output

    def train(self, mode=True):
        super().train(mode=mode)
        # Keep the frozen image encoder in eval mode even while the rest of the module trains.
        self.model.image_encoder.eval()
        self.training = mode
        return self

    @property
    def dtype(self):
        return self.model.dtype

    @torch.no_grad()
    def encode_image(self, image):
        """Run the frozen SAM image encoder on a PIL image."""
        image = np.array(image.convert(self.model.image_format))
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.model.device)
        transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        original_image_size = image.shape[:2]
        input_size = transformed_image.shape[-2:]

        features = self.model.image_encoder(self.model.preprocess(transformed_image))
        return features, original_image_size, input_size

    def generate_prompt_masks(self, masks, input_size):
        """Resize and pad mask logits to the 256x256 layout expected by the SAM prompt encoder."""
        pad_value = min(-1.0, masks.min().item())
        masks = F.interpolate(
            masks[:, None].float(), size=input_size, mode='bilinear').to(masks)
        h, w = masks.shape[-2:]
        masks = F.pad(masks, (0, self.model.image_encoder.img_size - w,
                              0, self.model.image_encoder.img_size - h), value=pad_value)
        prompt_masks = F.interpolate(
            masks.float(), size=(256, 256), mode='bilinear').to(masks)
        return prompt_masks

    def forward(self, image, pred_masks, text_embeds):
        # pred_masks are per-object mask logits
        image_embedding, original_image_size, input_size = self.encode_image(image)
        if self.training:
            # encode_image runs under no_grad; mark the features as requiring grad so the
            # decoder graph is differentiable during training (the encoder itself stays frozen).
            image_embedding.requires_grad = True
        prompt_masks = self.generate_prompt_masks(pred_masks, input_size)
        # Binarize the predictions at the original resolution to derive box prompts.
        pred_masks = F.interpolate(pred_masks.detach()[None].float().sigmoid(),
                                   size=original_image_size, mode='bilinear')[0]
        pred_masks = (pred_masks > 0.5).to(pred_masks)

        sam_masks = []
        for prompt_mask, pred_mask, text_embed in zip(prompt_masks, pred_masks, text_embeds):
            if self.use_box:
                if pred_mask.sum() > 0:
                    box = mask2box(pred_mask.float().cpu().numpy())
                else:
                    # Empty prediction: fall back to a box covering the whole image.
                    h, w = original_image_size
                    box = np.array([0.0, 0.0, w, h])
                box = self.transform.apply_boxes(box, original_image_size)
                box_torch = torch.as_tensor(
                    box, dtype=pred_mask.dtype, device=self.model.device)
                box_torch = box_torch[None, :]  # (1, 1, 4)
            else:
                box_torch = None
            sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=prompt_mask.view(1, 1, 256, 256) if self.use_mask else None,
            )
            if self.use_text:
                # Append the text embedding as an extra sparse prompt token.
                sparse_embeddings = torch.cat([sparse_embeddings.to(dense_embeddings),
                                               text_embed[None].to(dense_embeddings)], dim=1)
            else:
                sparse_embeddings = sparse_embeddings.to(dense_embeddings)
            low_res_masks, iou_predictions = self.model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=self.multimask_output,
            )
            sam_mask = self.model.postprocess_masks(
                low_res_masks, input_size, original_image_size)
            if self.multimask_output:
                # Keep the candidate whose binarized mask best overlaps the input prediction.
                candidate_masks = (sam_mask[0] > 0.0).float()
                candidate_ious = compute_mask_IoU(candidate_masks.view(3, -1),
                                                  pred_mask.float().view(1, -1))[-1]
                sam_mask = sam_mask[0, candidate_ious.argmax()]
            else:
                assert sam_mask.shape[1] == 1
                sam_mask = sam_mask[0, 0]
            sam_masks.append(sam_mask)
        return torch.stack(sam_masks)

    def state_dict(self, *args, **kwargs):
        # Drop the frozen image encoder weights so saved checkpoints stay small.
        state_dict = super().state_dict(*args, **kwargs)
        return {k: v for k, v in state_dict.items() if 'image_encoder' not in k}
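

# Minimal usage sketch (illustration only, not part of the original module): shows how
# SAMWrapper might be called. The model variant, checkpoint path, image file, and mask
# shapes below are assumptions made for the example.
if __name__ == '__main__':
    from PIL import Image

    wrapper = SAMWrapper('vit_h', checkpoint='sam_vit_h_4b8939.pth',  # assumed local SAM checkpoint
                         use_text=False, use_mask=True, use_box=True).eval()
    image = Image.open('example.jpg')  # assumed example image

    # Two coarse mask predictions in logit space (e.g. from an upstream segmentation head);
    # forward() resizes them to the SAM input resolution internally.
    pred_masks = torch.randn(2, 128, 128, device=wrapper.model.device)
    text_embeds = [None, None]  # unused because use_text=False

    with torch.no_grad():
        refined = wrapper(image, pred_masks, text_embeds)
    print(refined.shape)  # (2, H, W) refined mask logits at the original image size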