Spaces:
Sleeping
Sleeping
File size: 5,657 Bytes
bd9da36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones
class MaskDownSampler(nn.Module):
    """
    Progressively downsample a mask by total_stride, each time by stride.
    Note that LayerNorm is applied per *token*, like in ViT.

    With each downsample (by a factor stride**2), channel capacity increases by the same factor.
    In the end, we linearly project to embed_dim channels.

    The input is a single-channel mask (the first conv has 1 input channel),
    i.e. presumably shaped (N, 1, H, W).

    Args:
        embed_dim (int): Output channels of the final 1x1 projection.
        kernel_size (int): Kernel size of every strided downsampling conv.
        stride (int): Stride of every downsampling conv; must be >= 2.
        padding (int): Padding of every downsampling conv.
        total_stride (int): Overall downsampling factor; must equal
            ``stride**k`` for some integer k >= 0.
        activation (type): Module class instantiated after each LayerNorm2d.

    Raises:
        ValueError: if ``stride`` < 2, or ``total_stride`` is not an exact
            integer power of ``stride``.
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
    ):
        super().__init__()
        num_layers = self._num_downsample_layers(stride, total_stride)
        self.encoder = nn.Sequential()
        mask_in_chans, mask_out_chans = 1, 1
        for _ in range(num_layers):
            # A stride x stride downsample folds stride**2 pixels into one, so
            # channels grow by the same factor to preserve capacity.
            mask_out_chans = mask_in_chans * (stride**2)
            self.encoder.append(
                nn.Conv2d(
                    mask_in_chans,
                    mask_out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            self.encoder.append(LayerNorm2d(mask_out_chans))
            self.encoder.append(activation())
            mask_in_chans = mask_out_chans
        # Final 1x1 conv: a per-location linear projection to embed_dim.
        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))

    @staticmethod
    def _num_downsample_layers(stride, total_stride):
        """Return k such that ``stride**k == total_stride``, or raise ValueError.

        Uses exact integer arithmetic instead of float logarithms, which can
        round the quotient of two ``log2`` values to the wrong side of an
        integer for strides that are not powers of two.
        """
        if stride < 2:
            raise ValueError(f"stride must be >= 2, got {stride}")
        if total_stride < 1:
            raise ValueError(f"total_stride must be >= 1, got {total_stride}")
        num_layers, reached = 0, 1
        while reached < total_stride:
            reached *= stride
            num_layers += 1
        if reached != total_stride:
            raise ValueError(
                f"total_stride ({total_stride}) must be an integer power of "
                f"stride ({stride})"
            )
        return num_layers

    def forward(self, x):
        return self.encoder(x)
# Lightly adapted from ConvNeXt (https://github.com/facebookresearch/ConvNeXt)
class CXBlock(nn.Module):
    r"""ConvNeXt Block. The reference ConvNeXt describes two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back

    This implementation is a hybrid of the two:
        DwConv -> LayerNorm2d (on (N, C, H, W)) -> Permute to (N, H, W, C)
        -> Linear -> GELU -> Linear -> optional layer scale -> Permute back
        -> residual add (with optional stochastic depth).

    Args:
        dim (int): Number of input (and output) channels.
        kernel_size (int): Spatial kernel size of the first conv. Default: 7.
        padding (int): Padding of the first conv. It must keep H and W
            unchanged (e.g. ``(kernel_size - 1) // 2`` for odd kernels),
            otherwise the residual addition fails. Default: 3.
        drop_path (float): Stochastic depth rate; ``DropPath`` is used only
            when > 0, else identity. Default: 0.0
        layer_scale_init_value (float): Init value for the per-channel layer
            scale ``gamma``; layer scale is disabled when <= 0. Default: 1e-6.
        use_dwconv (bool): If True the first conv is depthwise
            (``groups=dim``); if False it is a dense conv. Default: True.
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        super().__init__()
        # Depthwise when use_dwconv is True, otherwise a full (dense) conv.
        self.dwconv = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim if use_dwconv else 1,
        )
        self.norm = LayerNorm2d(dim, eps=1e-6)
        # Pointwise / 1x1 convs, implemented as Linear layers on channels-last data.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            # Per-channel layer scale; broadcasts over the trailing C dim.
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        x = shortcut + self.drop_path(x)
        return x
class Fuser(nn.Module):
    """Apply an optional 1x1 input projection followed by a stack of layers.

    Args:
        layer: Module to replicate; the copies are produced by ``get_clones``
            (presumably independent deep copies — confirm in sam2_utils) and
            applied one after another.
        num_layers: Number of copies of ``layer`` to create.
        dim: Channel count for the input projection; must be given when
            ``input_projection`` is True.
        input_projection: When True, inputs first pass through a
            ``dim -> dim`` 1x1 ``nn.Conv2d``; otherwise through ``nn.Identity``.
    """

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        super().__init__()
        self.proj = nn.Identity()
        self.layers = get_clones(layer, num_layers)
        if not input_projection:
            return
        assert dim is not None
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        # Input is normally (N, C, H, W).
        out = self.proj(x)
        for blk in self.layers:
            out = blk(out)
        return out
class MemoryEncoder(nn.Module):
    """Fuse pixel features with an embedded mask into memory features.

    Pipeline in ``forward``: optional sigmoid on ``masks`` -> ``mask_downsampler``
    -> add to 1x1-projected ``pix_feat`` -> ``fuser`` -> output projection
    -> positional encoding of the result.

    Args:
        out_dim (int): Channels of the returned features. When it differs from
            ``in_dim`` a 1x1 conv projects ``in_dim -> out_dim``; otherwise the
            output projection is the identity.
        mask_downsampler (nn.Module): Maps masks to a tensor that is added to
            the projected pixel features (so presumably ``in_dim`` channels at
            the feature resolution — confirm against the configured module).
        fuser (nn.Module): Applied to the sum of projected features and the
            mask embedding.
        position_encoding (nn.Module): Called on the final features; its
            output is cast to their dtype.
        in_dim (int): Channels of ``pix_feat``.
    """

    def __init__(
        self,
        out_dim,
        mask_downsampler,
        fuser,
        position_encoding,
        in_dim=256,  # in_dim of pix_feats
    ):
        super().__init__()
        self.mask_downsampler = mask_downsampler
        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = fuser
        self.position_encoding = position_encoding
        self.out_proj = nn.Identity()
        if out_dim != in_dim:
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(
        self,
        pix_feat: torch.Tensor,
        masks: torch.Tensor,
        skip_mask_sigmoid: bool = False,
    ) -> dict:
        """Encode ``pix_feat`` together with ``masks``.

        Args:
            pix_feat: Pixel features with ``in_dim`` channels. They are moved
                to ``masks``' device before use.
            masks: Mask tensor fed to ``mask_downsampler``. It is passed
                through a sigmoid first unless ``skip_mask_sigmoid`` is True.
            skip_mask_sigmoid: Skip the sigmoid, presumably for masks that
                are already in [0, 1].

        Returns:
            dict: ``{"vision_features": x, "vision_pos_enc": [pos]}``, where
            ``x`` is the fused, projected feature map and ``pos`` is its
            positional encoding cast to ``x.dtype``, wrapped in a
            one-element list.
        """
        ## Process masks
        # Sigmoid maps mask values into (0, 1) so they are closer to
        # ground-truth masks, which are bool; this reduces the domain shift.
        if not skip_mask_sigmoid:
            masks = torch.sigmoid(masks)
        masks = self.mask_downsampler(masks)

        ## Fuse pix_feats and downsampled masks
        # The visual features may live on another device (e.g. CPU); move
        # them to wherever the masks are before combining.
        pix_feat = pix_feat.to(masks.device)
        x = self.pix_feat_proj(pix_feat)
        x = x + masks
        x = self.fuser(x)
        x = self.out_proj(x)

        pos = self.position_encoding(x).to(x.dtype)

        return {"vision_features": x, "vision_pos_enc": [pos]}
|