# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from typing import Optional

import torch
from torch import nn

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
        precompute_resolution: Optional[int] = None,
    ):
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        self.cache = {}

        # Precompute positional encodings under `precompute_resolution` to fill the cache
        # and avoid symbolic shape tracing errors in torch.compile in PyTorch 2.4 nightly.
        if precompute_resolution is not None:
            # We precompute pos enc for stride 4, 8, 16 and 32 to fill `self.cache`.
            precompute_sizes = [
                (precompute_resolution // 4, precompute_resolution // 4),
                (precompute_resolution // 8, precompute_resolution // 8),
                (precompute_resolution // 16, precompute_resolution // 16),
                (precompute_resolution // 32, precompute_resolution // 32),
            ]
            for size in precompute_sizes:
                # Note: this precomputation allocates on "cuda" and therefore assumes
                # a CUDA device is available.
                tensors = torch.zeros((1, 1) + size, device="cuda")
                self.forward(tensors)
                # further clone and detach it in the cache (just to be safe)
                self.cache[size] = self.cache[size].clone().detach()

    def _encode_xy(self, x, y):
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale

        # Standard transformer sinusoidal frequencies: temperature ** (2i / d)
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        # Interleave sin on even channels and cos on odd channels
        pos_x = torch.stack(
            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
        ).flatten(1)
        pos_y = torch.stack(
            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
        ).flatten(1)
        return pos_x, pos_y

    def encode_boxes(self, x, y, w, h):
        pos_x, pos_y = self._encode_xy(x, y)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Backwards compatibility

    def encode_points(self, x, y, labels):
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos
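
    # Shape sketch for `encode_points` (illustrative; the num_pos_feats=256 width below
    # is an assumed value, giving self.num_pos_feats == 128):
    #   x, y, labels: (B, N) normalized point coordinates and point labels
    #   pos_x, pos_y: (B, N, 128) each, after `_encode_xy`
    #   returned pos: (B, N, 257) = concat(pos_y, pos_x, labels[..., None])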

    def forward(self, x):
        # Reuse a cached encoding if one was already computed for this (H, W).
        cache_key = (x.shape[-2], x.shape[-1])
        if cache_key in self.cache:
            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)

        y_embed = (
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
            .view(1, -1, 1)
            .repeat(x.shape[0], 1, x.shape[-1])
        )
        x_embed = (
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
            .view(1, 1, -1)
            .repeat(x.shape[0], x.shape[-2], 1)
        )

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

        # Cache the single-image encoding so later calls with the same (H, W) reuse it.
        self.cache[cache_key] = pos[0]
        return pos
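

# Minimal usage sketch (not part of the original file): builds the module and queries a
# positional encoding for a dummy feature map. The num_pos_feats=256 width, the 64x64
# feature-map size, and the example point prompts below are assumed values for illustration.
if __name__ == "__main__":
    pe = PositionEmbeddingSine(num_pos_feats=256, normalize=True)
    feats = torch.zeros(2, 256, 64, 64)  # (batch, channels, H, W) dummy feature map
    pos = pe(feats)
    print(pos.shape)  # torch.Size([2, 256, 64, 64]); output channels == num_pos_feats

    # Point-prompt encoding: normalized (x, y) in [0, 1] plus one float label per point.
    xs = torch.tensor([[0.25, 0.75]])
    ys = torch.tensor([[0.50, 0.10]])
    labels = torch.tensor([[1.0, -1.0]])
    print(pe.encode_points(xs, ys, labels).shape)  # torch.Size([1, 2, 257])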