""" |
|
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a |
|
discrepancy, the original file should be regarded as the 'reference' version. |
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
import collections |
|
from dataclasses import dataclass |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
from transformers.activations_tf import ACT2FN |
|
from transformers.modeling_tf_outputs import TFBaseModelOutput |
|
from transformers.modeling_tf_utils import TFModelInputType, TFPreTrainedModel, shape_list, unpack_inputs |
|
from transformers.tf_utils import flatten, functional_layernorm |
|
from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging |
|
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_CONFIG_FOR_DOC = "SamConfig" |
|
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge" |
|
|
|
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [ |
|
"facebook/sam-vit-huge", |
|
"facebook/sam-vit-large", |
|
"facebook/sam-vit-base", |
|
|
|
] |
|
|
|
|
|
@dataclass |
|
class TFSamVisionEncoderOutput(ModelOutput): |
|
""" |
|
Base class for SAM vision model's outputs that also contains image embeddings obtained by applying the projection
|
layer to the pooler_output. |
|
|
|
Args: |
|
image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
|
The image embeddings obtained by applying the projection layer to the pooler_output. |
|
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): |
|
Sequence of hidden-states at the output of the last layer of the model. |
|
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
|
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for |
|
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
|
|
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. |
|
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
|
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
|
sequence_length)`. |
|
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
|
heads. |
|
""" |
|
|
|
image_embeds: tf.Tensor | None = None |
|
last_hidden_state: tf.Tensor = None |
|
hidden_states: Tuple[tf.Tensor] | None = None |
|
attentions: Tuple[tf.Tensor] | None = None |
|
|
|
|
|
@dataclass |
|
class TFSamImageSegmentationOutput(ModelOutput): |
|
""" |
|
Base class for the Segment-Anything model's output.
|
|
|
Args: |
|
iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`): |
|
The iou scores of the predicted masks. |
|
pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`): |
|
The predicted low-resolution masks. These need to be post-processed by the processor.
|
vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
|
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for |
|
the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
|
|
|
Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. |
|
vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
|
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
|
sequence_length)`. |
|
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
|
heads. |
|
mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
|
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
|
sequence_length)`. |
|
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
|
heads. |
|
""" |
|
|
|
iou_scores: tf.Tensor = None |
|
pred_masks: tf.Tensor = None |
|
vision_hidden_states: Tuple[tf.Tensor] | None = None |
|
vision_attentions: Tuple[tf.Tensor] | None = None |
|
mask_decoder_attentions: Tuple[tf.Tensor] | None = None |
|
|
|
|
|
class TFSamPatchEmbeddings(tf.keras.layers.Layer): |
|
""" |
|
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial |
|
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a |
|
Transformer. |
|
""" |
|
|
|
def __init__(self, config, **kwargs): |
|
super().__init__(**kwargs) |
|
image_size, patch_size = config.image_size, config.patch_size |
|
num_channels, hidden_size = config.num_channels, config.hidden_size |
|
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) |
|
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) |
|
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) |
|
self.image_size = image_size |
|
self.patch_size = patch_size |
|
self.num_channels = num_channels |
|
self.num_patches = num_patches |
|
|
|
self.projection = tf.keras.layers.Conv2D( |
|
hidden_size, kernel_size=patch_size, strides=patch_size, name="projection" |
|
) |
|
|
|
def call(self, pixel_values): |
|
batch_size, num_channels, height, width = shape_list(pixel_values) |
|
if num_channels != self.num_channels: |
|
raise ValueError( |
|
"Make sure that the channel dimension of the pixel values match with the one set in the configuration." |
|
) |
|
if height != self.image_size[0] or width != self.image_size[1]: |
|
raise ValueError( |
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." |
|
) |
|
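# tf.keras Conv2D expects channels-last input, so convert from (batch_size, num_channels, height, width)
# to (batch_size, height, width, num_channels) before projecting the patches.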
embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1])) |
|
return embeddings |
|
|
|
|
|
class TFSamMLPBlock(tf.keras.layers.Layer): |
|
def __init__(self, config, **kwargs): |
|
super().__init__(**kwargs) |
|
self.lin1 = tf.keras.layers.Dense(config.mlp_dim, name="lin1") |
|
self.lin2 = tf.keras.layers.Dense(config.hidden_size, name="lin2") |
|
self.act = ACT2FN[config.hidden_act] |
|
|
|
def call(self, hidden_states: tf.Tensor) -> tf.Tensor: |
|
hidden_states = self.lin1(hidden_states) |
|
hidden_states = self.act(hidden_states) |
|
hidden_states = self.lin2(hidden_states) |
|
return hidden_states |
|
|
|
|
|
class TFSamLayerNorm(tf.keras.layers.Layer): |
|
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. |
|
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, |
|
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). |
|
""" |
|
|
|
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs): |
|
super().__init__(**kwargs) |
|
self.eps = eps |
|
self.data_format = data_format |
|
self.normalized_shape = normalized_shape |
|
if self.data_format not in ["channels_last", "channels_first"]: |
|
raise NotImplementedError(f"Unsupported data format: {self.data_format}") |
|
|
|
def build(self, input_shape): |
|
self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight") |
|
self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias") |
|
super().build(input_shape) |
|
|
|
def call(self, x: tf.Tensor) -> tf.Tensor: |
|
if self.data_format == "channels_last": |
|
x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1) |
|
elif self.data_format == "channels_first": |
|
x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1) |
|
return x |
|
|
|
|
|
class TFSamAttention(tf.keras.layers.Layer): |
|
""" |
|
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and |
|
values. |
|
""" |
|
|
|
def __init__(self, config, downsample_rate=None, **kwargs): |
|
super().__init__(**kwargs) |
|
self.hidden_size = config.hidden_size |
|
|
|
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate |
|
|
|
self.internal_dim = config.hidden_size // downsample_rate |
|
self.num_attention_heads = config.num_attention_heads |
|
if self.internal_dim % config.num_attention_heads != 0: |
|
raise ValueError("num_attention_heads must evenly divide hidden_size // downsample_rate (the internal attention dim).")
|
|
|
self.q_proj = tf.keras.layers.Dense(self.internal_dim, name="q_proj") |
|
self.k_proj = tf.keras.layers.Dense(self.internal_dim, name="k_proj") |
|
self.v_proj = tf.keras.layers.Dense(self.internal_dim, name="v_proj") |
|
self.out_proj = tf.keras.layers.Dense(self.hidden_size, name="out_proj") |
|
|
|
def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor: |
|
batch, point_batch_size, n_tokens, channel = shape_list(hidden_states) |
|
c_per_head = channel // num_attention_heads |
|
hidden_states = tf.reshape( |
|
hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) |
|
) |
|
return tf.transpose(hidden_states, perm=[0, 2, 1, 3]) |
|
|
|
def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor: |
|
batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states) |
|
hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3]) |
|
return tf.reshape( |
|
hidden_states, |
|
(batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head), |
|
) |
|
|
|
def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor: |
|
|
|
query = self.q_proj(query) |
|
key = self.k_proj(key) |
|
value = self.v_proj(value) |
|
|
|
point_batch_size = shape_list(query)[1] |
|
|
|
query = self._separate_heads(query, self.num_attention_heads) |
|
key = self._separate_heads(key, self.num_attention_heads) |
|
value = self._separate_heads(value, self.num_attention_heads) |
|
|
|
|
|
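# Scaled dot-product attention: softmax(Q @ K^T / sqrt(c_per_head)) over tensors of shape
# (batch_size * point_batch_size, num_attention_heads, n_tokens, c_per_head).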
_, _, _, c_per_head = shape_list(query) |
|
attn = tf.matmul( |
|
query, tf.transpose(key, perm=[0, 1, 3, 2]) |
|
) |
|
attn = attn / tf.math.sqrt(float(c_per_head)) |
|
attn = tf.nn.softmax(attn, axis=-1) |
|
|
|
|
|
out = tf.matmul(attn, value) |
|
out = self._recombine_heads(out, point_batch_size) |
|
out = self.out_proj(out) |
|
|
|
return out |
|
|
|
|
|
class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer): |
|
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs): |
|
""" |
|
A transformer block with four layers: |
|
(1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on |
|
sparse inputs (4) cross attention of dense inputs -> sparse inputs |
|
|
|
Arguments: |
|
config (`SamMaskDecoderConfig`): |
|
The configuration file used to instantiate the block |
|
attention_downsample_rate (`int`, *optional*, defaults to 2):

The downsample ratio of the block, used to reduce the inner dimension of the attention.

skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
|
Whether or not to skip the addition of the query_point_embedding on the first layer. |
|
""" |
|
super().__init__(**kwargs) |
|
|
|
self.hidden_size = config.hidden_size |
|
self.layer_norm_eps = config.layer_norm_eps |
|
|
|
self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn") |
|
self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1") |
|
|
|
self.cross_attn_token_to_image = TFSamAttention( |
|
config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image" |
|
) |
|
self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2") |
|
|
|
self.mlp = TFSamMLPBlock(config, name="mlp") |
|
self.layer_norm3 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3") |
|
|
|
self.layer_norm4 = tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4") |
|
self.cross_attn_image_to_token = TFSamAttention( |
|
config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token" |
|
) |
|
|
|
self.skip_first_layer_pe = skip_first_layer_pe |
|
|
|
def call( |
|
self, |
|
queries: tf.Tensor, |
|
keys: tf.Tensor, |
|
query_point_embedding: tf.Tensor, |
|
key_point_embedding: tf.Tensor, |
|
output_attentions: bool = False, |
|
): |
|
|
|
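# Block 1: self-attention on the sparse (token) inputs; on the first layer the positional embedding can be skipped.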
if self.skip_first_layer_pe: |
|
queries = self.self_attn(query=queries, key=queries, value=queries) |
|
else: |
|
query = queries + query_point_embedding |
|
attn_out = self.self_attn(query=query, key=query, value=queries) |
|
queries = queries + attn_out |
|
queries = self.layer_norm1(queries) |
|
|
|
|
|
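# Block 2: cross-attention, tokens (queries) attending to the image embedding (keys).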
query = queries + query_point_embedding |
|
key = keys + key_point_embedding |
|
|
|
attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys) |
|
queries = queries + attn_out |
|
|
|
queries = self.layer_norm2(queries) |
|
|
|
|
|
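# Block 3: MLP on the token path.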
mlp_out = self.mlp(queries) |
|
queries = queries + mlp_out |
|
queries = self.layer_norm3(queries) |
|
|
|
|
|
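# Block 4: cross-attention in the reverse direction, the image embedding attending to the tokens.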
query = queries + query_point_embedding |
|
key = keys + key_point_embedding |
|
|
|
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) |
|
keys = keys + attn_out |
|
|
|
keys = self.layer_norm4(keys) |
|
|
|
outputs = (queries, keys) |
|
|
|
if output_attentions: |
|
outputs = outputs + (attn_out,) |
|
else: |
|
outputs = outputs + (None,) |
|
|
|
return outputs |
|
|
|
|
|
class TFSamTwoWayTransformer(tf.keras.layers.Layer): |
|
def __init__(self, config: SamMaskDecoderConfig, **kwargs): |
|
super().__init__(**kwargs) |
|
self.config = config |
|
|
|
self.num_hidden_layers = config.num_hidden_layers |
|
self.layers = [] |
|
|
|
for i in range(self.num_hidden_layers): |
|
self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}")) |
|
|
|
self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image") |
|
self.layer_norm_final_attn = tf.keras.layers.LayerNormalization( |
|
epsilon=config.layer_norm_eps, name="layer_norm_final_attn" |
|
) |
|
|
|
def call( |
|
self, |
|
point_embeddings: tf.Tensor, |
|
image_embeddings: tf.Tensor, |
|
image_positional_embeddings: tf.Tensor, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, TFBaseModelOutput]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
all_attentions = () |
|
|
|
if image_embeddings is None: |
|
raise ValueError("You have to specify an image_embedding") |
|
|
|
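# Flatten the spatial grid into a token sequence: (batch_size, channels, height, width) ->
# (batch_size, 1, height * width, channels), where the added axis is the point batch dimension.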
image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None] |
|
image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None] |
|
|
|
|
|
queries = point_embeddings |
|
keys = image_embeddings |
|
|
|
|
|
for layer in self.layers: |
|
queries, keys, attention_outputs = layer( |
|
queries=queries, |
|
keys=keys, |
|
query_point_embedding=point_embeddings, |
|
key_point_embedding=image_positional_embeddings, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
if output_attentions: |
|
all_attentions = all_attentions + (attention_outputs,) |
|
|
|
|
|
query = queries + point_embeddings |
|
key = keys + image_positional_embeddings |
|
|
|
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) |
|
|
|
queries = queries + attn_out |
|
queries = self.layer_norm_final_attn(queries) |
|
return queries, keys, all_attentions |
|
|
|
|
|
class TFSamFeedForward(tf.keras.layers.Layer): |
|
def __init__( |
|
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs |
|
): |
|
super().__init__(**kwargs) |
|
self.num_layers = num_layers |
|
self.activation = tf.keras.layers.ReLU() |
|
self.proj_in = tf.keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in") |
|
self.proj_out = tf.keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out") |
|
self.layers = [ |
|
tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}") |
|
for i in range(num_layers - 2) |
|
] |
|
self.sigmoid_output = sigmoid_output |
|
|
|
def call(self, hidden_states): |
|
hidden_states = self.proj_in(hidden_states) |
|
hidden_states = self.activation(hidden_states) |
|
for layer in self.layers: |
|
hidden_states = self.activation(layer(hidden_states)) |
|
|
|
hidden_states = self.proj_out(hidden_states) |
|
if self.sigmoid_output: |
|
hidden_states = tf.sigmoid(hidden_states) |
|
return hidden_states |
|
|
|
|
|
class TFSamMaskDecoder(tf.keras.layers.Layer): |
|
def __init__(self, config: SamMaskDecoderConfig, **kwargs): |
|
super().__init__(**kwargs) |
|
|
|
self.hidden_size = config.hidden_size |
|
|
|
self.num_multimask_outputs = config.num_multimask_outputs |
|
self.num_mask_tokens = config.num_multimask_outputs + 1 |
|
|
|
self.transformer = TFSamTwoWayTransformer(config, name="transformer") |
|
|
|
self.upscale_conv1 = tf.keras.layers.Conv2DTranspose( |
|
self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first" |
|
) |
|
self.upscale_conv2 = tf.keras.layers.Conv2DTranspose( |
|
self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first" |
|
) |
|
self.upscale_layer_norm = TFSamLayerNorm( |
|
self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm" |
|
) |
|
self.activation = tf.nn.gelu |
|
|
|
mlps_list = [] |
|
for i in range(self.num_mask_tokens): |
|
mlps_list += [ |
|
TFSamFeedForward( |
|
self.hidden_size, |
|
self.hidden_size, |
|
self.hidden_size // 8, |
|
3, |
|
name=f"output_hypernetworks_mlps_._{i}", |
|
) |
|
] |
|
self.output_hypernetworks_mlps = mlps_list |
|
|
|
self.iou_prediction_head = TFSamFeedForward( |
|
self.hidden_size, |
|
config.iou_head_hidden_dim, |
|
self.num_mask_tokens, |
|
config.iou_head_depth, |
|
name="iou_prediction_head", |
|
) |
|
|
|
def build(self, input_shape): |
|
self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True) |
|
self.mask_tokens = self.add_weight( |
|
shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True |
|
) |
|
super().build(input_shape) |
|
|
|
def call( |
|
self, |
|
image_embeddings: tf.Tensor, |
|
image_positional_embeddings: tf.Tensor, |
|
sparse_prompt_embeddings: tf.Tensor, |
|
dense_prompt_embeddings: tf.Tensor, |
|
multimask_output: bool, |
|
output_attentions: Optional[bool] = None, |
|
) -> Tuple[tf.Tensor, tf.Tensor]: |
|
batch_size, num_channels, height, width = shape_list(image_embeddings) |
|
point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1]) |
|
|
|
output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0) |
|
output_tokens = tf.tile( |
|
output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1] |
|
) |
|
|
|
|
|
|
|
|
|
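# Prepend the learned output tokens (one IoU token followed by num_mask_tokens mask tokens) to the
# sparse prompt embeddings along the token axis.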
if shape_list(sparse_prompt_embeddings)[1] != 0: |
|
tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2) |
|
else: |
|
tokens = output_tokens |
|
point_embeddings = tf.cast(tokens, self.iou_token.dtype) |
|
|
|
image_embeddings = image_embeddings + dense_prompt_embeddings |
|
image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1]) |
|
image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1]) |
|
|
|
point_embedding, image_embeddings, attentions = self.transformer( |
|
point_embeddings=point_embeddings, |
|
image_embeddings=image_embeddings, |
|
image_positional_embeddings=image_positional_embeddings, |
|
output_attentions=output_attentions, |
|
) |
|
iou_token_out = point_embedding[:, :, 0, :] |
|
mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] |
|
|
|
image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2)) |
|
image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width]) |
|
|
|
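# Upscale the image embedding by 4x with two stride-2 transposed convolutions (channels-first)
# before the hypernetwork-based mask prediction.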
upscaled_embedding = self.upscale_conv1(image_embeddings) |
|
upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) |
|
upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) |
|
|
|
hyper_in_list = [] |
|
for i in range(self.num_mask_tokens): |
|
current_mlp = self.output_hypernetworks_mlps[i] |
|
hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] |
|
hyper_in = tf.stack(hyper_in_list, axis=2) |
|
|
|
_, num_channels, height, width = shape_list(upscaled_embedding) |
|
upscaled_embedding = tf.reshape( |
|
upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width] |
|
) |
|
masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width]) |
|
|
|
iou_pred = self.iou_prediction_head(iou_token_out) |
|
|
|
if multimask_output: |
|
mask_slice = slice(1, None) |
|
else: |
|
mask_slice = slice(0, 1) |
|
masks = masks[:, :, mask_slice, :, :] |
|
iou_pred = iou_pred[:, :, mask_slice] |
|
|
|
outputs = (masks, iou_pred) |
|
|
|
if output_attentions: |
|
outputs = outputs + (attentions,) |
|
else: |
|
outputs = outputs + (None,) |
|
|
|
return outputs |
|
|
|
|
|
class TFSamPositionalEmbedding(tf.keras.layers.Layer): |
|
def __init__(self, config, **kwargs): |
|
super().__init__(**kwargs) |
|
self.scale = config.hidden_size // 2 |
|
self.config = config |
|
|
|
def build(self, input_shape): |
|
|
|
self.positional_embedding = self.add_weight( |
|
name="positional_embedding", |
|
shape=(2, self.config.num_pos_feats), |
|
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale), |
|
trainable=False, |
|
) |
|
super().build(input_shape) |
|
|
|
def call(self, input_coords, input_shape=None): |
|
"""Positionally encode points that are normalized to [0,1].""" |
|
coordinates = tf.identity(input_coords) |
|
|
|
if input_shape is not None: |
|
coordinates = tf.stack( |
|
[ |
|
tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1], |
|
tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0], |
|
], |
|
axis=-1, |
|
) |
|
|
|
|
|
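# Random Fourier features: map coordinates from [0, 1] to [-1, 1], project them with a fixed Gaussian
# matrix, scale by 2*pi, and take sine and cosine.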
coordinates = 2 * coordinates - 1 |
|
coordinates = tf.cast(coordinates, self.positional_embedding.dtype) |
|
coordinates = tf.matmul(coordinates, self.positional_embedding) |
|
coordinates = 2 * np.pi * coordinates |
|
|
|
return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1) |
|
|
|
|
|
class TFSamMaskEmbedding(tf.keras.layers.Layer): |
|
def __init__(self, config: SamPromptEncoderConfig, **kwargs): |
|
super().__init__(**kwargs) |
|
self.mask_input_channels = config.mask_input_channels // 4 |
|
self.activation = ACT2FN[config.hidden_act] |
|
self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1") |
|
self.conv2 = tf.keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2") |
|
self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3") |
|
self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1") |
|
self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2") |
|
|
|
def call(self, masks): |
|
masks = tf.transpose(masks, perm=(0, 2, 3, 1)) |
|
hidden_states = self.conv1(masks) |
|
hidden_states = self.layer_norm1(hidden_states) |
|
hidden_states = self.activation(hidden_states) |
|
|
|
hidden_states = self.conv2(hidden_states) |
|
hidden_states = self.layer_norm2(hidden_states) |
|
hidden_states = self.activation(hidden_states) |
|
dense_embeddings = self.conv3(hidden_states) |
|
dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) |
|
return dense_embeddings |
|
|
|
def build(self, input_shape): |
|
|
|
conv1_shape = [None, None, None, 1] |
|
conv2_shape = [None, None, None, self.mask_input_channels] |
|
conv3_shape = [None, None, None, self.mask_input_channels * 4] |
|
layer_norm1_shape = [None, None, None, self.mask_input_channels] |
|
layer_norm2_shape = [None, None, None, self.mask_input_channels * 4] |
|
with tf.name_scope("conv1"): |
|
self.conv1.build(conv1_shape) |
|
with tf.name_scope("conv2"): |
|
self.conv2.build(conv2_shape) |
|
with tf.name_scope("conv3"): |
|
self.conv3.build(conv3_shape) |
|
with tf.name_scope("layer_norm1"): |
|
self.layer_norm1.build(layer_norm1_shape) |
|
with tf.name_scope("layer_norm2"): |
|
self.layer_norm2.build(layer_norm2_shape) |
|
super().build(input_shape) |
|
|
|
|
|
class TFSamPromptEncoder(tf.keras.layers.Layer): |
|
def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs): |
|
super().__init__(**kwargs) |
|
self.shared_embedding = shared_patch_embedding |
|
self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed") |
|
self.no_mask_embed = None |
|
|
|
self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) |
|
self.input_image_size = config.image_size |
|
|
|
self.point_embed = [] |
|
self.hidden_size = config.hidden_size |
|
self.not_a_point_embed = None |
|
self.config = config |
|
|
|
def build(self, input_shape): |
|
self.no_mask_embed = self.add_weight( |
|
name="no_mask_embed.weight", |
|
shape=(1, self.hidden_size), |
|
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), |
|
trainable=True, |
|
) |
|
self.point_embed = [ |
|
self.add_weight( |
|
name=f"point_embed_._{i}.weight", |
|
shape=(1, self.hidden_size), |
|
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), |
|
trainable=True, |
|
) |
|
for i in range(self.config.num_point_embeddings) |
|
] |
|
self.not_a_point_embed = self.add_weight( |
|
name="not_a_point_embed.weight", |
|
shape=(1, self.hidden_size), |
|
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02), |
|
trainable=True, |
|
) |
|
with tf.name_scope("mask_embed"): |
|
|
|
self.mask_embed.build( |
|
(None, self.config.mask_input_channels, self.config.image_size, self.config.image_size) |
|
) |
|
super().build(input_shape) |
|
|
|
def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor: |
|
"""Embeds point prompts.""" |
|
points = points + 0.5 |
|
if pad: |
|
target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1]) |
|
target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1) |
|
padding_point = tf.zeros(target_point_shape, dtype=points.dtype) |
|
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype) |
|
points = tf.concat([points, padding_point], axis=2) |
|
labels = tf.concat([labels, padding_label], axis=2) |
|
input_shape = (self.input_image_size, self.input_image_size) |
|
point_embedding = self.shared_embedding(points, input_shape) |
|
|
|
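# Label semantics (see SAM_INPUTS_DOCSTRING): 1 -> point on the object, 0 -> point off the object,
# -1 -> background point (replaced by the 'not a point' embedding), -10 -> padding point (embedding zeroed out).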
point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding) |
|
|
|
point_embedding = tf.where( |
|
labels[..., None] != -10, |
|
point_embedding, |
|
tf.zeros_like(point_embedding), |
|
) |
|
point_embedding = tf.where( |
|
(labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding |
|
) |
|
point_embedding = tf.where( |
|
(labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding |
|
) |
|
return point_embedding |
|
|
|
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor: |
|
"""Embeds box prompts.""" |
|
boxes = boxes + 0.5 |
|
batch_size, nb_boxes = shape_list(boxes)[:2] |
|
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2)) |
|
input_shape = (self.input_image_size, self.input_image_size) |
|
corner_embedding = self.shared_embedding(coords, input_shape) |
|
corner_embedding += tf.where( |
|
tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0, |
|
self.point_embed[2][0], |
|
self.point_embed[3][0], |
|
) |
|
return corner_embedding |
|
|
|
def call( |
|
self, |
|
batch_size: Optional[int], |
|
input_points: Optional[Tuple[tf.Tensor, tf.Tensor]], |
|
input_labels: tf.Tensor | None, |
|
input_boxes: tf.Tensor | None, |
|
input_masks: tf.Tensor | None, |
|
) -> Tuple[tf.Tensor, tf.Tensor]: |
|
""" |
|
Embeds different types of prompts, returning both sparse and dense embeddings. |
|
|
|
Args: |
|
input_points (`tf.Tensor`, *optional*):

Point coordinates and labels to embed.

input_boxes (`tf.Tensor`, *optional*):

Boxes to embed.

input_masks (`tf.Tensor`, *optional*):

Masks to embed.
|
""" |
|
sparse_embeddings = None |
|
if input_points is not None: |
|
batch_size, point_batch_size = shape_list(input_points)[:2] |
|
if input_labels is None: |
|
raise ValueError("If points are provided, labels must also be provided.") |
|
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) |
|
sparse_embeddings = tf.zeros( |
|
(batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype |
|
) |
|
sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2) |
|
if input_boxes is not None: |
|
batch_size = shape_list(input_boxes)[0] |
|
box_embeddings = self._embed_boxes(input_boxes) |
|
if sparse_embeddings is None: |
|
sparse_embeddings = box_embeddings |
|
else: |
|
sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2) |
|
if input_masks is not None: |
|
dense_embeddings = self.mask_embed(input_masks) |
|
else: |
|
dense_embeddings = self.no_mask_embed[0] |
|
dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1)) |
|
dense_embeddings = tf.tile( |
|
dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1]) |
|
) |
|
if sparse_embeddings is None: |
|
sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype) |
|
|
|
return sparse_embeddings, dense_embeddings |
|
|
|
|
|
class TFSamVisionAttention(tf.keras.layers.Layer): |
|
"""Multi-head Attention block with relative position embeddings.""" |
|
|
|
def __init__(self, config, window_size, **kwargs): |
|
super().__init__(**kwargs) |
|
input_size = ( |
|
(config.image_size // config.patch_size, config.image_size // config.patch_size) |
|
if window_size == 0 |
|
else (window_size, window_size) |
|
) |
|
self.input_size = input_size |
|
|
|
self.num_attention_heads = config.num_attention_heads |
|
head_dim = config.hidden_size // config.num_attention_heads |
|
self.head_dim = head_dim |
|
self.scale = head_dim**-0.5 |
|
self.dropout = config.attention_dropout |
|
|
|
self.qkv = tf.keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv") |
|
self.proj = tf.keras.layers.Dense(config.hidden_size, name="proj") |
|
|
|
self.use_rel_pos = config.use_rel_pos |
|
if self.use_rel_pos: |
|
if input_size is None: |
|
raise ValueError("Input size must be provided if using relative positional encoding.") |
|
|
|
def build(self, input_shape): |
|
if self.input_size is not None: |
|
|
|
self.rel_pos_h = self.add_weight( |
|
shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h" |
|
) |
|
self.rel_pos_w = self.add_weight( |
|
shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w" |
|
) |
|
super().build(input_shape) |
|
|
|
def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor: |
|
""" |
|
Get relative positional embeddings according to the relative positions of |
|
query and key sizes. |
|
|
|
Args: |
|
q_size (int): |
|
size of the query. |
|
k_size (int): |
|
size of key k. |
|
rel_pos (`tf.Tensor`): |
|
relative position embeddings (L, channel). |
|
|
|
Returns: |
|
Extracted positional embeddings according to relative positions. |
|
""" |
|
max_rel_dist = int(2 * max(q_size, k_size) - 1) |
|
|
|
if rel_pos.shape[0] != max_rel_dist: |
|
|
|
rel_pos_resized = tf.image.resize( |
|
tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)), |
|
size=(max_rel_dist, rel_pos.shape[1]), |
|
method="bilinear", |
|
) |
|
rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist)) |
|
else: |
|
rel_pos_resized = rel_pos |
|
|
|
|
|
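# Scale the query and key coordinates when the two grids differ in size, then gather rows from the
# (2 * max(q_size, k_size) - 1)-entry relative-position table.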
q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0) |
|
k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0) |
|
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) |
|
|
|
return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32)) |
|
|
|
def add_decomposed_rel_pos( |
|
self, |
|
attn: tf.Tensor, |
|
query: tf.Tensor, |
|
rel_pos_h: tf.Tensor, |
|
rel_pos_w: tf.Tensor, |
|
q_size: Tuple[int, int], |
|
k_size: Tuple[int, int], |
|
) -> tf.Tensor: |
|
""" |
|
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. |
|
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py |
|
|
|
Args: |
|
attn (`tf.Tensor`): |
|
attention map. |
|
query (`tf.Tensor`): |
|
query q in the attention layer with shape (batch_size, query_height * query_width, channel). |
|
rel_pos_h (`tf.Tensor`): |
|
relative position embeddings (Lh, channel) for height axis. |
|
rel_pos_w (`tf.Tensor`): |
|
relative position embeddings (Lw, channel) for width axis. |
|
q_size (tuple): |
|
spatial sequence size of query q with (query_height, query_width). |
|
k_size (tuple): |
|
spatial sequence size of key k with (key_height, key_width). |
|
|
|
Returns: |
|
attn (`tf.Tensor`): |
|
attention map with added relative positional embeddings. |
|
""" |
|
query_height, query_width = q_size |
|
key_height, key_width = k_size |
|
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) |
|
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) |
|
|
|
batch_size, _, dim = shape_list(query) |
|
reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim)) |
|
rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) |
|
rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) |
|
attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width)) |
|
attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2) |
|
attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width)) |
|
return attn |
|
|
|
def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor: |
|
batch_size, height, width, _ = shape_list(hidden_states) |
|
|
|
qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1)) |
|
qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4)) |
|
|
|
query, key, value = tf.unstack( |
|
tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0 |
|
) |
|
attn_weights = tf.matmul(query * self.scale, key, transpose_b=True) |
|
|
|
if self.use_rel_pos: |
|
attn_weights = self.add_decomposed_rel_pos( |
|
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) |
|
) |
|
|
|
attn_weights = tf.nn.softmax(attn_weights, axis=-1) |
|
|
|
if training: |
|
attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout) |
|
else: |
|
attn_probs = attn_weights |
|
|
|
attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1)) |
|
attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4)) |
|
attn_output = tf.reshape(attn_output, (batch_size, height, width, -1)) |
|
|
|
attn_output = self.proj(attn_output) |
|
|
|
if output_attentions: |
|
outputs = (attn_output, attn_weights) |
|
else: |
|
outputs = (attn_output, None) |
|
|
|
return outputs |
|
|
|
|
|
class TFSamVisionLayer(tf.keras.layers.Layer): |
|
def __init__(self, config, window_size, **kwargs): |
|
super().__init__(**kwargs) |
|
self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1") |
|
self.attn = TFSamVisionAttention(config, window_size, name="attn") |
|
self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2") |
|
self.mlp = TFSamMLPBlock(config, name="mlp") |
|
self.window_size = window_size |
|
|
|
def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]: |
|
batch_size, height, width, channel = shape_list(hidden_states) |
|
|
|
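# Pad height and width up to a multiple of the window size, then reshape into non-overlapping
# (window_size x window_size) windows stacked along the batch dimension.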
pad_h = (window_size - height % window_size) % window_size |
|
pad_w = (window_size - width % window_size) % window_size |
|
if pad_h > 0 or pad_w > 0: |
|
hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]) |
|
pad_height, pad_width = height + pad_h, width + pad_w |
|
|
|
hidden_states = tf.reshape( |
|
hidden_states, |
|
[batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel], |
|
) |
|
windows = tf.reshape( |
|
tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel] |
|
) |
|
return windows, (pad_height, pad_width) |
|
|
|
def window_unpartition( |
|
self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int] |
|
) -> tf.Tensor: |
|
pad_height, pad_width = padding_shape |
|
height, width = original_shape |
|
batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size) |
|
hidden_states = tf.reshape( |
|
windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1] |
|
) |
|
hidden_states = tf.reshape( |
|
tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1] |
|
) |
|
|
|
if pad_height > height or pad_width > width: |
|
hidden_states = hidden_states[:, :height, :width, :] |
|
return hidden_states |
|
|
|
def call( |
|
self, |
|
hidden_states: tf.Tensor, |
|
output_attentions: Optional[bool] = False, |
|
training: Optional[bool] = False, |
|
) -> Tuple[tf.Tensor]: |
|
residual = hidden_states |
|
|
|
hidden_states = self.layer_norm1(hidden_states) |
|
if self.window_size > 0: |
|
height, width = hidden_states.shape[1], hidden_states.shape[2] |
|
hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) |
|
|
|
hidden_states, attn_weights = self.attn( |
|
hidden_states=hidden_states, |
|
output_attentions=output_attentions, |
|
training=training, |
|
) |
|
if self.window_size > 0: |
|
hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) |
|
|
|
hidden_states = residual + hidden_states |
|
layernorm_output = self.layer_norm2(hidden_states) |
|
hidden_states = hidden_states + self.mlp(layernorm_output) |
|
|
|
outputs = (hidden_states,) |
|
if output_attentions: |
|
outputs += (attn_weights,) |
|
|
|
return outputs |
|
|
|
|
|
class TFSamVisionNeck(tf.keras.layers.Layer): |
|
def __init__(self, config: SamVisionConfig, **kwargs): |
|
super().__init__(**kwargs) |
|
self.config = config |
|
|
|
self.conv1 = tf.keras.layers.Conv2D( |
|
config.output_channels, |
|
kernel_size=1, |
|
use_bias=False, |
|
name="conv1", |
|
) |
|
self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1") |
|
self.conv2 = tf.keras.layers.Conv2D( |
|
config.output_channels, |
|
kernel_size=3, |
|
padding="same", |
|
use_bias=False, |
|
name="conv2", |
|
) |
|
self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2") |
|
|
|
def call(self, hidden_states): |
|
hidden_states = self.conv1(hidden_states) |
|
hidden_states = self.layer_norm1(hidden_states) |
|
|
|
hidden_states = self.conv2(hidden_states) |
|
hidden_states = self.layer_norm2(hidden_states) |
|
hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2]) |
|
return hidden_states |
|
|
|
|
|
class TFSamVisionEncoder(tf.keras.layers.Layer): |
|
def __init__(self, config: SamVisionConfig, **kwargs): |
|
super().__init__(**kwargs) |
|
self.config = config |
|
self.image_size = config.image_size |
|
|
|
self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed") |
|
|
|
self.pos_embed = None |
|
|
|
self.layers = [] |
|
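# Layers whose index appears in config.global_attn_indexes use global attention (window_size=0);
# all other layers use windowed attention of size config.window_size.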
for i in range(config.num_hidden_layers): |
|
layer = TFSamVisionLayer( |
|
config, |
|
window_size=config.window_size if i not in config.global_attn_indexes else 0, |
|
name=f"layers_._{i}", |
|
) |
|
self.layers.append(layer) |
|
|
|
self.neck = TFSamVisionNeck(config, name="neck") |
|
|
|
def build(self, input_shape): |
|
if self.config.use_abs_pos: |
|
|
|
self.pos_embed = self.add_weight( |
|
shape=[ |
|
1, |
|
self.config.image_size // self.config.patch_size, |
|
self.config.image_size // self.config.patch_size, |
|
self.config.hidden_size, |
|
], |
|
initializer="zeros", |
|
trainable=True, |
|
name="pos_embed", |
|
) |
|
super().build(input_shape) |
|
|
|
def get_input_embeddings(self): |
|
return self.patch_embed |
|
|
|
def call( |
|
self, |
|
pixel_values: tf.Tensor | None = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
training: Optional[bool] = False, |
|
) -> Union[Tuple, TFSamVisionEncoderOutput]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if pixel_values is None: |
|
raise ValueError("You have to specify pixel_values") |
|
|
|
hidden_states = self.patch_embed(pixel_values) |
|
if self.pos_embed is not None: |
|
hidden_states = hidden_states + self.pos_embed |
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attentions = () if output_attentions else None |
|
|
|
for i, layer_module in enumerate(self.layers): |
|
if output_hidden_states: |
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if output_attentions: |
|
all_self_attentions = all_self_attentions + (layer_outputs[1],) |
|
|
|
if output_hidden_states: |
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
hidden_states = self.neck(hidden_states) |
|
|
|
if not return_dict: |
|
outputs = (hidden_states,) |
|
if output_hidden_states: |
|
outputs = outputs + (all_hidden_states,) |
|
if output_attentions: |
|
outputs = outputs + (all_self_attentions,) |
|
return outputs |
|
|
|
return TFSamVisionEncoderOutput( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attentions, |
|
) |
|
|
|
|
|
class TFSamPreTrainedModel(TFPreTrainedModel): |
|
config_class = SamConfig |
|
base_model_prefix = "sam" |
|
main_input_name = "pixel_values" |
|
|
|
|
|
SAM_START_DOCSTRING = r""" |
|
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
etc.) |
|
|
|
This model is also a TensorFlow [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) |
|
subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to |
|
general usage and behavior. |
|
|
|
Parameters: |
|
config ([`SamConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
|
|
SAM_INPUTS_DOCSTRING = r""" |
|
Args: |
|
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): |
|
Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for |
|
details. |
|
input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`): |
|
Input 2D spatial points; these are used by the prompt encoder to encode the prompt. Generally yields much
|
better results. The points can be obtained by passing a list of list of list to the processor that will |
|
create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second |
|
dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per |
|
input point), the third dimension is the number of points per segmentation mask (it is possible to pass |
|
multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) |
|
coordinates of the point. If a different number of points is passed either for each image, or for each |
|
mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the |
|
computation of the embedding will be skipped for these points using the labels. |
|
input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`): |
|
Input labels for the points; these are used by the prompt encoder to encode the prompt. According to the
|
official implementation, there are 3 types of labels |
|
|
|
- `1`: the point is a point that contains the object of interest |
|
- `0`: the point is a point that does not contain the object of interest |
|
- `-1`: the point corresponds to the background |
|
|
|
We added the label: |
|
|
|
- `-10`: the point is a padding point, thus should be ignored by the prompt encoder |
|
|
|
The padding labels should be automatically done by the processor. |
|
input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`): |
|
Input boxes for the points; these are used by the prompt encoder to encode the prompt. Generally yields
|
much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, |
|
that will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size, |
|
the number of boxes per image and the coordinates of the top left and bottom right point of the box. In the
|
order (`x1`, `y1`, `x2`, `y2`): |
|
|
|
- `x1`: the x coordinate of the top left point of the input box |
|
- `y1`: the y coordinate of the top left point of the input box |
|
- `x2`: the x coordinate of the bottom right point of the input box |
|
- `y2`: the y coordinate of the bottom right point of the input box |
|
|
|
input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): |
|
The SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to

generate a corresponding embedding that is later fed to the mask decoder. These masks need to be

provided manually by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
|
|
|
image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`): |
|
Image embeddings; these are used by the mask decoder to generate masks and IoU scores. For more memory
|
efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` |
|
method, and then feed them to the `call` method instead of feeding the `pixel_values`. |
|
multimask_output (`bool`, *optional*): |
|
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per |
|
bounding box if relevant). However, it is possible to output just a single mask that corresponds to the
|
"best" mask, by specifying `multimask_output=False`. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
|
|
@add_start_docstrings( |
|
"Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", |
|
" optional 2D location and bounding boxes.", |
|
SAM_START_DOCSTRING, |
|
) |
|
class TFSamModel(TFSamPreTrainedModel): |
|
_keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"] |
|
|
|
def __init__(self, config, **kwargs): |
|
super().__init__(config, **kwargs) |
|
self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding") |
|
|
|
self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder") |
|
self.prompt_encoder = TFSamPromptEncoder( |
|
config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder" |
|
) |
|
self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder") |
|
self.config = config |
|
|
|
def get_input_embeddings(self): |
|
return self.vision_encoder.get_input_embeddings() |
|
|
|
def get_image_wide_positional_embeddings(self): |
|
size = self.config.prompt_encoder_config.image_embedding_size |
|
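# Build a (size, size) grid of normalized (x, y) coordinates and encode it with the shared random Fourier
# positional embedding; the result has shape (1, channels, size, size).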
grid = tf.ones((size, size)) |
|
y_embed = tf.math.cumsum(grid, axis=0) - 0.5 |
|
x_embed = tf.math.cumsum(grid, axis=1) - 0.5 |
|
y_embed = y_embed / size |
|
x_embed = x_embed / size |
|
|
|
positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1)) |
|
return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) |
|
|
|
def get_image_embeddings( |
|
self, |
|
pixel_values, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
): |
|
r""" |
|
Returns the image embeddings by passing the pixel values through the vision encoder. |
|
|
|
Args: |
|
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): |
|
Input pixel values |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
|
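Example (a minimal usage sketch; `SamProcessor` and the checkpoint name follow the standard `transformers`
API, and `image` is a placeholder for any PIL image or NumPy array, not something defined in this file):

```python
>>> from transformers import SamProcessor, TFSamModel

>>> processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
>>> model = TFSamModel.from_pretrained("facebook/sam-vit-base")
>>> inputs = processor(images=image, return_tensors="tf")
>>> image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
```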
""" |
|
vision_output = self.vision_encoder( |
|
pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
image_embeddings = vision_output[0] |
|
return image_embeddings |
|
|
|
def get_prompt_embeddings( |
|
self, |
|
input_points: tf.Tensor | None = None, |
|
input_labels: tf.Tensor | None = None, |
|
input_boxes: tf.Tensor | None = None, |
|
input_masks: tf.Tensor | None = None, |
|
): |
|
r""" |
|
Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. |
|
|
|
Args: |
|
input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): |
|
Optional input points for the prompt encoder. The padding of the point is automatically done by the |
|
processor. `point_batch_size` refers to the number of masks that we want the model to predict per |
|
point. The model will output `point_batch_size` times 3 masks in total. |
|
input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): |
|
Optional input labels for the prompt encoder. The padding of the labels is automatically done by the |
|
processor, or can be fed by the user. |
|
input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`): |
|
Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the |
|
processor. Users can also pass the input boxes manually.
|
input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`): |
|
Optional input masks for the prompt encoder. |
|
""" |
|
prompt_output = self.prompt_encoder( |
|
input_points=input_points, |
|
input_labels=input_labels, |
|
input_boxes=input_boxes, |
|
input_masks=input_masks, |
|
) |
|
return prompt_output |
|
|
|
@unpack_inputs |
|
@add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING) |
|
def call( |
|
self, |
|
pixel_values: TFModelInputType | None = None, |
|
input_points: tf.Tensor | None = None, |
|
input_labels: tf.Tensor | None = None, |
|
input_boxes: tf.Tensor | None = None, |
|
input_masks: tf.Tensor | None = None, |
|
image_embeddings: tf.Tensor | None = None, |
|
multimask_output: bool = True, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict=None, |
|
training=False, |
|
**kwargs, |
|
) -> List[Dict[str, tf.Tensor]]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if pixel_values is None and image_embeddings is None: |
|
raise ValueError("Either pixel_values or image_embeddings must be provided.") |
|
|
|
if pixel_values is not None and image_embeddings is not None: |
|
raise ValueError("Only one of pixel_values and image_embeddings can be provided.") |
|
|
|
if input_points is not None and len(input_points.shape) != 4: |
|
raise ValueError( |
|
"The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.", |
|
" got {}.".format(input_points.shape), |
|
) |
|
if input_boxes is not None and len(input_boxes.shape) != 3: |
|
raise ValueError( |
|
"The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.", |
|
" got {}.".format(input_boxes.shape), |
|
) |
|
if input_points is not None and input_boxes is not None: |
|
point_batch_size = shape_list(input_points)[1] |
|
box_batch_size = shape_list(input_boxes)[1] |
|
if point_batch_size != box_batch_size: |
|
raise ValueError( |
|
"You should provide as many bounding boxes as input points per box. Got {} and {}.".format( |
|
point_batch_size, box_batch_size |
|
) |
|
) |
|
if pixel_values is not None: |
|
|
|
pixel_values = tf.ensure_shape( |
|
pixel_values, |
|
[ |
|
None, |
|
self.config.vision_config.num_channels, |
|
self.config.vision_config.image_size, |
|
self.config.vision_config.image_size, |
|
], |
|
) |
|
image_positional_embeddings = self.get_image_wide_positional_embeddings() |
|
|
|
batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0] |
|
image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0) |
|
|
|
vision_attentions = None |
|
vision_hidden_states = None |
|
|
|
if pixel_values is not None: |
|
vision_outputs = self.vision_encoder( |
|
pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=True, |
|
training=training, |
|
) |
|
image_embeddings = vision_outputs["last_hidden_state"] |
|
|
|
if output_hidden_states: |
|
vision_hidden_states = vision_outputs["hidden_states"] |
|
if output_attentions: |
|
vision_attentions = vision_outputs["attentions"] |
|
|
|
if input_points is not None and input_labels is None: |
|
input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32) |
|
|
|
if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]: |
|
raise ValueError( |
|
"The batch size of the image embeddings and the input points must be the same. ", |
|
"Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]), |
|
" if you want to pass multiple points for the same image, make sure that you passed ", |
|
" input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ", |
|
" input_labels of shape (batch_size, point_batch_size, num_points_per_image)", |
|
) |
|
|
|
sparse_embeddings, dense_embeddings = self.prompt_encoder( |
|
batch_size=shape_list(image_embeddings)[0], |
|
input_points=input_points, |
|
input_labels=input_labels, |
|
input_boxes=input_boxes, |
|
input_masks=input_masks, |
|
) |
|
|
|
low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( |
|
image_embeddings=image_embeddings, |
|
image_positional_embeddings=image_positional_embeddings, |
|
sparse_prompt_embeddings=sparse_embeddings, |
|
dense_prompt_embeddings=dense_embeddings, |
|
multimask_output=multimask_output, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
if not return_dict: |
|
output = (iou_predictions, low_res_masks) |
|
if output_hidden_states: |
|
output = output + (vision_hidden_states,) |
|
|
|
if output_attentions: |
|
output = output + (vision_attentions, mask_decoder_attentions) |
|
return output |
|
|
|
return TFSamImageSegmentationOutput( |
|
iou_scores=iou_predictions, |
|
pred_masks=low_res_masks, |
|
vision_hidden_states=vision_hidden_states, |
|
vision_attentions=vision_attentions, |
|
mask_decoder_attentions=mask_decoder_attentions, |
|
) |
|
|
|
def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput: |
|
hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None |
|
attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None |
|
|
|
return TFSamImageSegmentationOutput( |
|
iou_scores=output.iou_scores, |
|
pred_masks=output.pred_masks, |
|
vision_hidden_states=hs if self.config.output_hidden_states else None, |
|
vision_attentions=attns if self.config.output_attentions else None, |
|
mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None, |
|
) |
|
|