# modeling/utils/interactive.py
import os
import copy
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F

def rand_sample(x, divisor, max_len):
    # Randomly sample at most `max_len` nonzero coordinates from the mask tensor `x`,
    # scaling the coordinates by `divisor` and giving each mask id equal total probability.
    # non_zero_pos_point = [rand_sample((m.nonzero()/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
    if len(x.nonzero()) == 0:
        # no nonzero points: return an empty [ndim, 0] index tensor
        return x.nonzero().t()
    non_zero_point_index = (x.nonzero() / divisor).t()
    mask_ids = non_zero_point_index[0].unique().long()
    # compute the sampling probability for each point: uniform across mask ids,
    # then uniform across the points belonging to each id
    probs = torch.zeros_like(non_zero_point_index[0])
    for idx in mask_ids:
        prob = 1. / (len(mask_ids) * ((non_zero_point_index[0:1] == idx).sum()))
        probs[non_zero_point_index[0] == idx] = prob
    indices = torch.multinomial(probs, num_samples=min(max_len, len(probs)), replacement=False).sort()[0]
    non_zero_point_index = non_zero_point_index[:, indices]
    return non_zero_point_index  # [ndim, min(max_len, num_points)]
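
# A minimal usage sketch for rand_sample (illustrative only, not part of the original
# module): the mask shapes, divisor, and max_len below are assumed values.
def _demo_rand_sample():
    # three stacked 8x8 binary masks; the divisor keeps the mask id intact (first
    # coordinate / 1) and normalizes the spatial coordinates by the mask size
    masks = torch.zeros(3, 8, 8)
    masks[0, :2, :2] = 1
    masks[1, 4:6, 4:6] = 1
    masks[2, 7, 7] = 1
    divisor = torch.tensor([1.0, 8.0, 8.0])
    points = rand_sample(masks, divisor, max_len=4)
    # points has shape [3, min(4, num_nonzero_points)]; each column is (mask id, y, x)
    return points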

def rand_sample_plain(x, max_len):
    # Uniformly subsample the columns of `x` (shape [C, N]) down to at most `max_len`.
    if x.shape[1] <= max_len:
        return x
    else:
        rand_idx = torch.randperm(x.shape[1])[:max_len]
        return x[:, rand_idx]
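
# A minimal usage sketch for rand_sample_plain (illustrative only): uniformly
# subsample the columns of a [C, N] coordinate tensor; the sizes below are assumptions.
def _demo_rand_sample_plain():
    points = torch.randn(3, 1000)              # e.g. 1000 candidate points, 3 coords each
    sampled = rand_sample_plain(points, max_len=512)
    # all columns are kept when N <= max_len; otherwise a random subset of 512 columns
    return sampled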

def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed):
    # Project each feature level, add its level embedding and positional encoding,
    # and flatten everything for the transformer decoder.
    src = []
    pos = []
    size_list = []

    # disable mask, it does not affect performance
    for i in range(num_feature_levels):
        size_list.append(x[i].shape[-2:])
        pos.append(pe_layer(x[i], None).flatten(2))
        src.append(input_proj[i](x[i]).flatten(2) + level_embed.weight[i][None, :, None])

        # flatten NxCxHxW to HWxNxC
        pos[-1] = pos[-1].permute(2, 0, 1)
        src[-1] = src[-1].permute(2, 0, 1)

    return src, pos, size_list
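
# A minimal usage sketch for prepare_features (illustrative assumptions throughout):
# _DummyPE stands in for the real positional-encoding layer (e.g. a sine embedding),
# and the channel sizes, batch size, and feature resolutions are made up.
def _demo_prepare_features():
    num_levels, in_dim, hidden_dim, batch = 3, 256, 128, 2

    class _DummyPE(nn.Module):
        # returns an all-zero [N, hidden_dim, H, W] positional map; the real layer
        # would produce sine/cosine encodings of the same shape
        def forward(self, feat, mask=None):
            n, _, h, w = feat.shape
            return torch.zeros(n, hidden_dim, h, w, device=feat.device)

    pe_layer = _DummyPE()
    input_proj = nn.ModuleList([nn.Conv2d(in_dim, hidden_dim, kernel_size=1) for _ in range(num_levels)])
    level_embed = nn.Embedding(num_levels, hidden_dim)
    feats = [torch.randn(batch, in_dim, s, s) for s in (32, 16, 8)]

    src, pos, size_list = prepare_features(feats, num_levels, pe_layer, input_proj, level_embed)
    # each src[i] / pos[i] is [H_i*W_i, N, hidden_dim]; size_list[i] is (H_i, W_i)
    return src, pos, size_list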