Spaces:
Runtime error
Runtime error
import math | |
import torch | |
from .pair_wise_distance import PairwiseDistFunction | |
from .sparse_utils import naive_sparse_bmm | |
def calc_init_centroid(images, num_spixels_width, num_spixels_height):
    """
    Build the initial superpixel centroids and the pixel-to-superpixel map.

    The image is split into a regular ``num_spixels_height`` x
    ``num_spixels_width`` grid. Each centroid is the adaptive average pool of
    the features over one grid cell, and each pixel is labelled with the
    row-major index of the cell obtained by nearest-neighbour upsampling of
    the cell-index grid.

    Args:
        images: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels_width: int
            number of superpixels along the width axis
        num_spixels_height: int
            number of superpixels along the height axis

    Return:
        centroids: torch.Tensor
            A Tensor of shape (B, C, num_spixels_height * num_spixels_width)
        init_label_map: torch.Tensor
            A Tensor of shape (B, H * W) with the same type as ``centroids``,
            holding the initial superpixel index of every pixel
    """
    n_batch, n_channels, img_h, img_w = images.shape
    grid_shape = (num_spixels_height, num_spixels_width)

    # Pooling happens outside no_grad, so centroids keep their autograd history.
    pooled = torch.nn.functional.adaptive_avg_pool2d(images, grid_shape)

    # The label map is pure bookkeeping and is built without gradient tracking.
    with torch.no_grad():
        n_cells = num_spixels_width * num_spixels_height
        cell_ids = torch.arange(n_cells, device=images.device)
        cell_ids = cell_ids.reshape(1, 1, *pooled.shape[-2:]).type_as(pooled)
        label_grid = torch.nn.functional.interpolate(
            cell_ids, size=(img_h, img_w), mode="nearest")
        label_grid = label_grid.repeat(n_batch, 1, 1, 1)

    flat_labels = label_grid.reshape(n_batch, -1)
    flat_centroids = pooled.reshape(n_batch, n_channels, -1)
    return flat_centroids, flat_labels
def get_abs_indices(init_label_map, num_spixels_width):
    """
    Build (batch, superpixel, pixel) index triples for the 3x3 neighbourhood.

    For every pixel, the nine candidate superpixels are its initial
    superpixel plus the eight surrounding grid cells, expressed as absolute
    superpixel indices.

    Args:
        init_label_map: torch.Tensor
            A Tensor of shape (B, N) holding the initial superpixel index of
            each pixel
        num_spixels_width: int
            number of superpixels per grid row (the stride between rows)

    Return:
        torch.Tensor
            A long Tensor of shape (3, B * 9 * N); rows are batch index,
            absolute superpixel index and pixel index. Superpixel indices are
            NOT clipped here, so they may be negative or past the last
            superpixel.
    """
    n_batch, n_pixels = init_label_map.shape
    dev = init_label_map.device

    # Offsets of the 3x3 neighbourhood in row-major superpixel numbering.
    step = torch.arange(-1, 2.0, device=dev)
    neighbour_offsets = torch.cat(
        [step - num_spixels_width, step, step + num_spixels_width], 0)

    batch_idx = torch.arange(n_batch, device=dev)[:, None, None].repeat(1, 9, n_pixels)
    spix_idx = init_label_map[:, None] + neighbour_offsets[None, :, None]
    pix_idx = torch.arange(n_pixels, device=dev)[None, None].repeat(n_batch, 9, 1)

    flattened = [idx.reshape(-1).long() for idx in (batch_idx, spix_idx, pix_idx)]
    return torch.stack(flattened, 0)
def get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width):
    """
    Turn soft neighbourhood affinities into absolute superpixel labels.

    Args:
        affinity_matrix: torch.Tensor
            A Tensor of shape (B, 9, N); dimension 1 indexes the 3x3
            neighbourhood around each pixel's initial superpixel
        init_label_map: torch.Tensor
            A Tensor of shape (B, N) holding the initial superpixel index of
            each pixel
        num_spixels_width: int
            number of superpixels per grid row

    Return:
        torch.Tensor
            A long Tensor of shape (B, N): the initial label shifted by the
            offset of the neighbour with the highest affinity
    """
    step = torch.arange(-1, 2.0, device=affinity_matrix.device)
    neighbour_offsets = torch.cat(
        [step - num_spixels_width, step, step + num_spixels_width], 0)

    # Position (0..8) of the strongest neighbour for every pixel.
    best_neighbour = torch.max(affinity_matrix, dim=1).indices
    absolute = init_label_map + neighbour_offsets[best_neighbour]
    return absolute.long()
def sparse_ssn_iter(pixel_features, num_spixels, n_iter):
    """
    computing assignment iterations with sparse matrix

    detailed process is in Algorithm 1, line 2 - 6

    NOTE: this function does NOT guarantee the backward computation.

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            A number of superpixels
        n_iter: int
            A number of iterations (must be at least 1)

    Return:
        sparse_abs_affinity: torch.Tensor
            sparse COO Tensor indexed by (batch, superpixel, pixel) holding
            the soft assignments from the last iteration
        hard_labels: torch.Tensor
            A long Tensor of shape (B, H * W) with the superpixel index of
            each pixel
        spixel_features: torch.Tensor
            updated superpixel features; presumably (B, C, n_superpixels),
            depending on what ``naive_sparse_bmm`` returns - TODO confirm

    Raises:
        ValueError: if ``n_iter`` is smaller than 1 (there would be no
            assignment to return)
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be at least 1, got {n_iter}")

    height, width = pixel_features.shape[-2:]
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    spixel_features, init_label_map = \
        calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height)
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    # Neighbour indices falling outside [0, num_spixels) are dropped. The mask
    # depends only on the initial label map, so it is computed once.
    # NOTE(review): the bound is the requested num_spixels, which can exceed
    # num_spixels_width * num_spixels_height after int() truncation - confirm
    # this is intended.
    mask = (abs_indices[1] >= 0) & (abs_indices[1] < num_spixels)
    valid_indices = abs_indices[:, mask]

    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)
    permuted_pixel_features = pixel_features.permute(0, 2, 1)

    for _ in range(n_iter):
        dist_matrix = PairwiseDistFunction.apply(
            pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height)

        # Soft assignment over the 9 candidate superpixels of every pixel.
        affinity_matrix = (-dist_matrix).softmax(1)
        reshaped_affinity_matrix = affinity_matrix.reshape(-1)

        sparse_abs_affinity = torch.sparse_coo_tensor(valid_indices, reshaped_affinity_matrix[mask])

        # Affinity-weighted mean of pixel features; 1e-16 avoids division by zero.
        spixel_features = naive_sparse_bmm(sparse_abs_affinity, permuted_pixel_features) \
            / (torch.sparse.sum(sparse_abs_affinity, 2).to_dense()[..., None] + 1e-16)
        spixel_features = spixel_features.permute(0, 2, 1)

    hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width)

    return sparse_abs_affinity, hard_labels, spixel_features
def ssn_iter(pixel_features, num_spixels, n_iter):
    """
    computing assignment iterations

    detailed process is in Algorithm 1, line 2 - 6

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            A number of superpixels
        n_iter: int
            A number of iterations (must be at least 1)

    Return:
        abs_affinity: torch.Tensor
            dense Tensor indexed by (batch, superpixel, pixel) holding the
            soft assignments from the last iteration
        hard_labels: torch.Tensor
            A long Tensor of shape (B, H * W) with the superpixel index of
            each pixel
        spixel_features: torch.Tensor
            updated superpixel features of shape (B, C, n_superpixels)

    Raises:
        ValueError: if ``n_iter`` is smaller than 1 (there would be no
            assignment to return)
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be at least 1, got {n_iter}")

    height, width = pixel_features.shape[-2:]
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    # spixel_features: (B, C, K), e.g. 10 * 202 * 64
    # init_label_map: (B, H * W), e.g. 10 * 40000
    spixel_features, init_label_map = \
        calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height)

    # get indices of the 9 neighbors
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    # Neighbour indices falling outside [0, num_spixels) are dropped. The mask
    # depends only on the initial label map, so it is computed once.
    # NOTE(review): the bound is the requested num_spixels, which can exceed
    # num_spixels_width * num_spixels_height after int() truncation - confirm
    # this is intended.
    mask = (abs_indices[1] >= 0) & (abs_indices[1] < num_spixels)
    valid_indices = abs_indices[:, mask]

    # (B, C, H * W), e.g. 10 * 202 * 40000
    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)
    # (B, H * W, C), e.g. 10 * 40000 * 202
    permuted_pixel_features = pixel_features.permute(0, 2, 1).contiguous()

    for _ in range(n_iter):
        # (B, 9, H * W), e.g. 10 * 9 * 40000
        dist_matrix = PairwiseDistFunction.apply(
            pixel_features, spixel_features, init_label_map, num_spixels_width, num_spixels_height)

        # Soft assignment over the 9 candidate superpixels of every pixel.
        affinity_matrix = (-dist_matrix).softmax(1)
        reshaped_affinity_matrix = affinity_matrix.reshape(-1)

        # (B, K, H * W), e.g. 10 * 64 * 40000
        sparse_abs_affinity = torch.sparse_coo_tensor(valid_indices, reshaped_affinity_matrix[mask])
        abs_affinity = sparse_abs_affinity.to_dense().contiguous()

        # Affinity-weighted mean of pixel features; 1e-16 avoids division by zero.
        spixel_features = torch.bmm(abs_affinity, permuted_pixel_features) \
            / (abs_affinity.sum(2, keepdim=True) + 1e-16)
        # back to (B, C, K) for the next distance computation
        spixel_features = spixel_features.permute(0, 2, 1).contiguous()

    hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width)

    return abs_affinity, hard_labels, spixel_features
def ssn_iter2(pixel_features, num_spixels, n_iter, init_spixel_features, temp=1):
    """
    computing assignment iterations for second layer

    Unlike the neighbourhood-restricted first layer, every pixel is compared
    against ALL superpixel centers.

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, N)
        num_spixels: int
            A number of superpixels (not used by the computation; the count
            is taken from ``init_spixel_features``)
        n_iter: int
            A number of iterations (must be at least 1)
        init_spixel_features: torch.Tensor
            A Tensor of shape (B, C, num_spixels)
        temp: float
            factor multiplied onto the negated distances before the softmax;
            larger values give sharper assignments

    Return:
        aff: torch.Tensor
            A Tensor of shape (B, num_spixels, N) with the soft assignments
            computed in the last iteration (i.e. against the centers as they
            were before the final update)
        hard_labels: torch.Tensor
            A long Tensor of shape (B, N) with the index of the superpixel
            with the highest affinity for each pixel
        spixel_features: torch.Tensor
            A Tensor of shape (B, num_spixels, C) with the updated centers

    Raises:
        ValueError: if ``n_iter`` is smaller than 1 (there would be no
            assignment to return)
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be at least 1, got {n_iter}")

    spixel_features = init_spixel_features.permute(0, 2, 1)  # B, num_spixels, C
    pixel_features = pixel_features.permute(0, 2, 1)  # B, N, C

    for _ in range(n_iter):
        # compute distance to all spixel_features
        dist = torch.cdist(pixel_features, spixel_features)  # B, N, num_spixels
        aff = (-dist * temp).softmax(-1).permute(0, 2, 1)  # B, num_spixels, N
        # compute new superpixels centers; 1e-6 avoids division by zero
        spixel_features = torch.bmm(aff, pixel_features) \
            / (aff.sum(2, keepdim=True) + 1e-6)  # B, num_spixels, C

    hard_labels = torch.argmax(aff, dim=1)

    return aff, hard_labels, spixel_features