Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from metaseg.modeling import Sam | |
from metaseg.utils.amg import calculate_stability_score | |
class SamOnnxModel(nn.Module):
    """
    Wrapper around a Sam model that is intended for ONNX export.

    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of
    Sam, with some functions modified to enable model tracing. Also supports
    extra options controlling what information is returned. See the ONNX
    export script for details.
    """

    def __init__(
        self,
        model: "Sam",
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        """
        Store the wrapped model and the export options.

        Args:
            model: The Sam model whose prompt encoder and mask decoder are
                used by ``forward``.
            return_single_mask: If True, ``forward`` keeps only the mask
                chosen by ``select_masks`` instead of every predicted mask.
            use_stability_score: If True, ``forward`` replaces the decoder's
                scores with ``calculate_stability_score`` of the low
                resolution masks.
            return_extra_metrics: If True, ``forward`` additionally returns
                stability scores and areas computed on the upscaled masks.
        """
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        # Offset handed to calculate_stability_score in forward().
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    # Must be a staticmethod: it takes no ``self`` yet is invoked through the
    # instance in mask_postprocessing(); without the decorator that call
    # would pass the instance as ``input_image_size`` and raise TypeError.
    @staticmethod
    def resize_longest_image_size(input_image_size: torch.Tensor, longest_side: int) -> torch.Tensor:
        """
        Scale a size so that its largest entry equals ``longest_side``.

        Args:
            input_image_size: Tensor holding the size to rescale
                (used as (height, width) by mask_postprocessing).
            longest_side: Target length for the largest entry.

        Returns:
            An int64 tensor with every entry scaled by
            ``longest_side / max(input_image_size)`` and rounded half up.
        """
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        # floor(x + 0.5) rounds half up using plain tensor ops.
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """
        Build the sparse prompt embedding for a set of labelled points.

        Args:
            point_coords: Point coordinates expressed in the ``img_size``
                frame (they are divided by ``self.img_size`` below).
            point_labels: One label per point; -1 marks "not a point", a
                label ``i`` selects ``point_embeddings[i]``.
                Assumes one fewer dimension than the positional encoding,
                since it is unsqueezed and expanded to match - TODO confirm
                exact shapes against the prompt encoder.

        Returns:
            The positional encoding of the points combined with the
            label-specific learned embeddings.
        """
        # Half-unit shift before normalising (presumably moves coordinates
        # to pixel centres - confirm against the prompt encoder).
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Label selection is expressed as multiplication by boolean masks
        # rather than indexing or branching, which keeps the graph traceable.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[i].weight * (
                point_labels == i
            )

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        """
        Build the dense prompt embedding from an optional mask input.

        Args:
            input_mask: Mask tensor passed through ``mask_downscaling``.
            has_mask_input: Weight applied to the downscaled mask; its
                complement ``1 - has_mask_input`` weights the learned
                "no mask" embedding (presumably 1.0 or 0.0 - confirm
                against the export script).

        Returns:
            The weighted sum of the downscaled mask embedding and the
            ``no_mask_embed`` weight reshaped to (1, -1, 1, 1).
        """
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (1 - has_mask_input) * self.model.prompt_encoder.no_mask_embed.weight.reshape(
            1, -1, 1, 1
        )
        return mask_embedding

    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
        """
        Resize predicted masks to the original image size.

        The masks are upsampled to (img_size, img_size), cropped to the
        size returned by ``resize_longest_image_size`` and finally
        interpolated to ``orig_im_size``.

        Args:
            masks: Mask tensor with two trailing spatial dimensions.
            orig_im_size: Tensor whose first two entries are the target
                height and width.

        Returns:
            The masks bilinearly resized to (orig_im_size[0], orig_im_size[1]).
        """
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        # Crop to the resized-image extent before the final resize.
        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
        masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Keep, for each batch entry, the single highest scoring mask.

        Args:
            masks: Candidate masks indexed as (batch, mask, H, W).
            iou_preds: Score per candidate, indexed as (batch, mask).
            num_points: Number of prompt points; biases the choice towards
                or away from the first mask token.

        Returns:
            A tuple ``(masks, iou_preds)`` where the mask dimension has been
            reduced to size 1 by picking the best candidate.
        """
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor([[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]).to(
            iou_preds.device
        )
        # With fewer than 2.5 points the first token is pushed down by a
        # large negative value, with more it is pushed up, so argmax picks
        # accordingly without an ``if``.
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        """
        Run prompt encoding, mask decoding and mask postprocessing.

        Args:
            image_embeddings: Precomputed image embeddings fed to the mask
                decoder.
            point_coords: Prompt point coordinates (see ``_embed_points``).
            point_labels: Prompt point labels (see ``_embed_points``).
            mask_input: Mask prompt (see ``_embed_masks``).
            has_mask_input: Switch for the mask prompt (see ``_embed_masks``).
            orig_im_size: Size the output masks are resized to.

        Returns:
            ``(upscaled_masks, scores, masks)``, or
            ``(upscaled_masks, scores, stability_scores, areas, masks)``
            when ``return_extra_metrics`` is set. ``masks`` are the low
            resolution masks from the decoder (after optional selection).
        """
        # NOTE(review): gradients are not disabled here; confirm whether
        # export/inference callers are expected to wrap this in
        # torch.no_grad().
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(masks, self.model.mask_threshold, self.stability_score_offset)

        if self.return_single_mask:
            # point_coords.shape[1] is used as the number of prompt points.
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            # Count of above-threshold pixels over the two spatial dims.
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks