# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from einops import rearrange, repeat

from .blocks import EfficientUpdateFormer, CorrBlock
from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
from .modules import Mlp


class BaseTrackerPredictor(nn.Module):
    """
    The base template to create a track predictor.

    Modified from https://github.com/facebookresearch/co-tracker/
    and https://github.com/facebookresearch/vggsfm
    """

    def __init__(
        self,
        stride=1,
        corr_levels=5,
        corr_radius=4,
        latent_dim=128,
        hidden_size=384,
        use_spaceatt=True,
        depth=6,
        max_scale=518,
        predict_conf=True,
    ):
        super().__init__()
        self.stride = stride
        self.latent_dim = latent_dim
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.hidden_size = hidden_size
        self.max_scale = max_scale
        self.predict_conf = predict_conf

        self.flows_emb_dim = latent_dim // 2
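
        # The correlation block samples a (2 * corr_radius + 1)^2 window at each
        # of corr_levels pyramid levels; this MLP projects those responses to latent_dim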
        self.corr_mlp = Mlp(
            in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
            hidden_features=self.hidden_size,
            out_features=self.latent_dim,
        )
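
        # Transformer input per token: the 2D flow embedding (latent_dim) plus
        # 4 channels of raw flows scaled by max_scale, the projected correlation
        # features (latent_dim), and the current track features (latent_dim)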
        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
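
        # Two learnable tokens: one for the query (first) frame, one shared by
        # all other frames, so the transformer can tell the reference frame apart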
        self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))

        space_depth = depth if use_spaceatt else 0
        time_depth = depth
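
        # The update transformer outputs latent_dim + 2 channels per token:
        # a 2-channel coordinate delta plus a latent_dim-channel feature delta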
        self.updateformer = EfficientUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=self.transformer_dim,
            hidden_size=self.hidden_size,
            output_dim=self.latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=use_spaceatt,
        )

        self.fmap_norm = nn.LayerNorm(self.latent_dim)
        self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)

        # A linear layer to update track feats at each iteration
        self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
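
        # Heads that map the final track features to per-frame visibility and
        # (optionally) confidence logits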
        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

        if predict_conf:
            self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

    def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
        """
        query_points: B x N x 2, the number of batches, tracks, and xy
        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
            Note that HH and WW are the spatial sizes of the feature maps, not of the original images.
        iters: number of refinement iterations
        return_feat: if True, also return the track and query-point features
        down_ratio: extra downsampling factor between the original images and fmaps
        apply_sigmoid: if True, apply a sigmoid to the visibility/confidence logits
        """
        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape

        assert D == 2, "Input points must be 2D coordinates"

        # apply a layernorm to fmaps here
        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
        fmaps = fmaps.permute(0, 1, 4, 2, 3)

        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride
        # e.g., if a 3 x 1024 x 1024 image is processed to a 128 x 256 x 256 feature map,
        # its query_points should be query_points / 4
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)

        # Init with coords as the query points
        # It means the search will start from the position of the query points at the reference frames
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

        # Sample/extract the features of the query points in the query frame
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

        # init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C

        # back up the init coords
        coords_backup = coords.clone()

        fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)

        coord_preds = []

        # Iterative Refinement
        for _ in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()

            fcorrs = fcorr_fn.corr_sample(track_feats, coords)
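            # fcorrs: B x S x N x corr_dim, correlation responses sampled around
            # the current coordinate estimate of every track in every frame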
            corr_dim = fcorrs.shape[3]
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
            fcorrs_ = self.corr_mlp(fcorrs_)

            # Movement of current coords relative to query points
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)

            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)

            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)

            # 2D positional embed
            # TODO: this can be much simplified
            pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
            sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
            sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)

            x = transformer_input + sampled_pos_emb
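
            # The query (first) frame keeps its own learnable token while the
            # remaining S - 1 frames share the second one (see __init__)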
            # Add the query ref token to the track feats
            query_ref_token = torch.cat(
                [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
            )
            x = x + query_ref_token.to(x.device).to(x.dtype)

            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)

            # Compute the delta coordinates and delta track features
            delta, _ = self.updateformer(x)

            # BN, S, C
            delta = rearrange(delta, "b n s d -> (b n) s d", b=B)
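
            # Split the output into a 2-channel coordinate delta and a
            # latent_dim-channel feature delta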
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

            # Update the track features
            track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)  # B x S x N x C

            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)

            # Force coord0 as query
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]

            # The predicted tracks are in the original image scale
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)
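
        # Visibility (and optionally confidence) are predicted once, from the
        # track features after the final refinement iteration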
        # B, S, N
        vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
        if apply_sigmoid:
            vis_e = torch.sigmoid(vis_e)

        if self.predict_conf:
            conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
            if apply_sigmoid:
                conf_e = torch.sigmoid(conf_e)
        else:
            conf_e = None

        if return_feat:
            return coord_preds, vis_e, track_feats, query_track_feat, conf_e
        else:
            return coord_preds, vis_e, conf_e
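

if __name__ == "__main__":
    # Minimal shape-level smoke test: a sketch, not part of the original module.
    # It assumes the sibling modules (.blocks, .utils, .modules) are importable,
    # so run it as a package module, e.g. `python -m <package>.<this_module>`.
    # All sizes below are made up for illustration.
    B, S, N = 1, 4, 16  # batch, frames, tracks
    HH = WW = 32  # feature-map resolution
    C = 128  # channel count; must match latent_dim

    predictor = BaseTrackerPredictor()
    fmaps = torch.randn(B, S, C, HH, WW)  # stands in for backbone features
    query_points = torch.rand(B, N, 2) * WW  # xy in feature-map scale (stride=1)

    coord_preds, vis_e, conf_e = predictor(query_points, fmaps, iters=4)
    # coord_preds: list of length iters, each B x S x N x 2 (input scale)
    # vis_e, conf_e: B x S x N, in [0, 1] since apply_sigmoid=True by default
    print(coord_preds[-1].shape, vis_e.shape, conf_e.shape)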