# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Miscellaneous utility functions
"""

import torch


def cat(tensors, dim=0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a
    single element in a list.

    Args:
        tensors: list or tuple of tensors to concatenate.
        dim: dimension along which to concatenate.

    Returns:
        ``tensors[0]`` itself (not a copy) when exactly one element is given,
        otherwise the result of ``torch.cat(tensors, dim)``.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def permute_and_flatten(layer, N, A, C, H, W):
    """
    Rearrange a per-level prediction map into one row of C values per
    (location, anchor) pair.

    The input is viewed as (N, -1, C, H, W), permuted to (N, H, W, -1, C)
    and flattened to (N, -1, C).

    Args:
        layer: tensor that can be viewed as (N, -1, C, H, W).
        N: size of the leading (batch) dimension.
        A: number of anchors; unused here because that dimension is
            inferred by the view, kept for interface compatibility.
        C: number of values predicted per anchor.
        H: height of the map.
        W: width of the map.

    Returns:
        Tensor of shape (N, -1, C).
    """
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer


def concat_box_prediction_layers(box_regression, box_cls=None, token_logits=None):
    """
    Flatten per-feature-level predictions and concatenate them over levels.

    Args:
        box_regression: iterable of tensors, one per feature level, each
            unpackable as (N, A * 4, H, W).
        box_cls: optional iterable of tensors, one per feature level, each
            unpackable as (N, A * C, H, W). When ``None``, only the
            regression (and token) outputs are processed.
        token_logits: optional iterable of tensors, one per feature level,
            each unpackable as (N, A * T, H, W).

    Returns:
        Tuple ``(box_regression, box_cls, token_logits_stacked)``:
        ``box_regression`` reshaped to (-1, 4); ``box_cls`` reshaped to
        (-1, C), or ``None`` when ``box_cls`` was not given;
        ``token_logits_stacked`` of shape (N, -1, T), or ``None`` when
        ``token_logits`` was not given.

    Raises:
        ValueError: if no feature level is available to process.
    """
    box_regression_flattened = []
    box_cls_flattened = []
    token_logit_flattened = []

    # ``box_cls`` defaults to None, so it must not be fed to zip() blindly.
    if box_cls is None:
        per_level = ((None, regression) for regression in box_regression)
    else:
        per_level = zip(box_cls, box_regression)

    # for each feature level, permute the outputs to make them be in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in per_level:
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        if box_cls_per_level is None:
            N, _, H, W = box_regression_per_level.shape
        else:
            N, AxC, H, W = box_cls_per_level.shape
            C = AxC // A
            box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
            box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        box_regression_flattened.append(box_regression_per_level)

    # Fail with a clear message instead of an unbound ``A``/``C`` or an
    # opaque torch.cat error on an empty list.
    if not box_regression_flattened:
        raise ValueError("concat_box_prediction_layers requires at least one feature level")

    if token_logits is not None:
        for token_logit_per_level in token_logits:
            N, AXT, H, W = token_logit_per_level.shape
            # ``A`` is the anchor count of the last level processed above.
            T = AXT // A
            token_logit_per_level = permute_and_flatten(token_logit_per_level, N, A, T, H, W)
            token_logit_flattened.append(token_logit_per_level)

    # concatenate along dim 1 (the flattened location/anchor dimension, which
    # gathers all feature levels), to take into account the way the labels
    # were generated (with all feature maps being concatenated as well)
    if box_cls is not None:
        # ``C`` comes from the last level; the reshape only works when every
        # level shares it.
        box_cls = cat(box_cls_flattened, dim=1).reshape(-1, C)
    box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)

    token_logits_stacked = None
    if token_logits is not None:
        # stacked
        token_logits_stacked = cat(token_logit_flattened, dim=1)

    return box_regression, box_cls, token_logits_stacked


def round_channels(channels, divisor=8):
    """
    Round a channel count to a multiple of ``divisor``.

    The value is rounded to the nearest multiple (never below ``divisor``
    itself); if that result is under 90% of ``channels``, one more
    ``divisor`` is added.

    Args:
        channels: requested number of channels.
        divisor: the multiple to round to.

    Returns:
        The rounded channel count.
    """
    rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor)
    if float(rounded_channels) < 0.9 * channels:
        rounded_channels += divisor
    return rounded_channels