zdou0830's picture
desco
749745d
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Miscellaneous utility functions
"""
import torch
def cat(tensors, dim=0):
    """
    Concatenate a list/tuple of tensors along ``dim``.

    When the sequence holds exactly one tensor, that tensor is returned
    as-is (same object), which skips the copy ``torch.cat`` would make.

    Args:
        tensors: a ``list`` or ``tuple`` of tensors.
        dim: dimension to concatenate along (forwarded to ``torch.cat``).

    Returns:
        The single input tensor itself, or the result of ``torch.cat``.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    # Single element: hand back the original tensor, no copy.
    return tensors[0]
def permute_and_flatten(layer, N, A, C, H, W):
    """
    Rearrange a ``(N, G*C, H, W)`` tensor into ``(N, H*W*G, C)``.

    The channel axis is split into groups of size ``C`` (the group count
    ``G`` is inferred by ``view``), the two spatial axes are moved in front
    of the group axis, and everything between the batch axis and the
    trailing ``C`` axis is flattened.

    Args:
        layer: tensor viewable as ``(N, -1, C, H, W)``.
        N: size of the leading (batch) axis.
        A: accepted for signature compatibility; not used by the body,
            since the group count is inferred with ``-1``.
        C: size of the trailing axis of the result.
        H, W: sizes of the last two axes of ``layer``.

    Returns:
        Tensor of shape ``(N, -1, C)``.
    """
    # Split channels into (groups, C); the group count is inferred.
    grouped = layer.view(N, -1, C, H, W)
    # Reorder to (N, H, W, groups, C) so spatial positions vary slowest.
    spatial_first = grouped.permute(0, 3, 4, 1, 2)
    # reshape (not view): the permuted tensor is generally non-contiguous.
    return spatial_first.reshape(N, -1, C)
def concat_box_prediction_layers(box_regression, box_cls=None, token_logits=None):
    """
    Flatten per-feature-level predictions and concatenate them across levels.

    Each per-level tensor is 4-D ``(N, channels, H, W)``. The number of
    anchors per location ``A`` is derived from the box regression channels
    (``channels // 4``); class and token channel counts are divided by ``A``
    to obtain the per-anchor sizes ``C`` and ``T``.

    Args:
        box_regression: iterable of per-level tensors with ``A * 4`` channels.
        box_cls: optional iterable of per-level tensors with ``A * C``
            channels, paired level-by-level with ``box_regression``.
            When ``None``, only the regression (and token) outputs are
            processed and ``None`` is returned in its place.
        token_logits: optional iterable of per-level tensors with ``A * T``
            channels.

    Returns:
        Tuple ``(box_regression, box_cls, token_logits_stacked)``:
        ``box_regression`` reshaped to ``(-1, 4)``; ``box_cls`` reshaped to
        ``(-1, C)`` or ``None``; ``token_logits_stacked`` of shape
        ``(N, -1, T)`` or ``None`` when ``token_logits`` is ``None``.

    Raises:
        ValueError: if no feature level is available to process.
    """
    box_regression_flattened = []
    box_cls_flattened = []
    token_logit_flattened = []

    A = None
    C = None

    # box_cls is optional: without it, walk the regression levels alone.
    if box_cls is None:
        levels = ((None, regression) for regression in box_regression)
    else:
        levels = zip(box_cls, box_regression)

    # For each feature level, permute the outputs so they share the layout
    # of the labels, which are computed for all feature levels concatenated.
    for box_cls_per_level, box_regression_per_level in levels:
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        if box_cls_per_level is None:
            N, _, H, W = box_regression_per_level.shape
        else:
            # Batch/spatial sizes are taken from the class tensor when given.
            N, AxC, H, W = box_cls_per_level.shape
            C = AxC // A
            box_cls_flattened.append(
                permute_and_flatten(box_cls_per_level, N, A, C, H, W)
            )
        box_regression_flattened.append(
            permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        )

    if not box_regression_flattened:
        # Fail early with a clear message instead of an unbound-name error.
        raise ValueError(
            "concat_box_prediction_layers requires at least one feature level"
        )

    if token_logits is not None:
        for token_logit_per_level in token_logits:
            N, AXT, H, W = token_logit_per_level.shape
            # NOTE: uses the anchor count of the last box level processed
            # above; this assumes every level shares the same A.
            T = AXT // A
            token_logit_flattened.append(
                permute_and_flatten(token_logit_per_level, N, A, T, H, W)
            )

    # Concatenate along dim 1 (the flattened locations of every feature
    # level), matching how the labels were generated, then collapse the
    # batch axis as well for boxes and classes.
    box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)
    if box_cls is not None:
        box_cls = cat(box_cls_flattened, dim=1).reshape(-1, C)

    token_logits_stacked = None
    if token_logits is not None:
        # Token logits keep their batch axis: (N, total_locations, T).
        token_logits_stacked = cat(token_logit_flattened, dim=1)

    return box_regression, box_cls, token_logits_stacked
def round_channels(channels, divisor=8):
    """
    Round ``channels`` to a multiple of ``divisor``.

    The value is rounded to the nearest multiple (ties go up), never below
    ``divisor`` itself. If that result is less than 90% of ``channels``,
    one more ``divisor`` is added.

    Args:
        channels: the channel count to round.
        divisor: the multiple to round to (default 8).

    Returns:
        The rounded channel count.
    """
    half_step = divisor / 2.0
    multiples = int(channels + half_step) // divisor
    candidate = max(multiples * divisor, divisor)
    # Accept the candidate unless rounding down lost more than 10%.
    if float(candidate) >= 0.9 * channels:
        return candidate
    return candidate + divisor