import math

import torch

from .pair_wise_distance import PairwiseDistFunction
from .sparse_utils import naive_sparse_bmm


def _require_positive_n_iter(n_iter):
    """Raise ValueError unless at least one assignment iteration is requested.

    The iteration functions only define their outputs inside the loop body,
    so a count below one would otherwise fail later with an opaque
    UnboundLocalError.
    """
    if n_iter < 1:
        raise ValueError("n_iter must be at least 1, got {}".format(n_iter))


def _relative_spix_indices(num_spixels_width, device):
    """Return the 9 label offsets of the 3x3 superpixel neighbourhood.

    The offsets are ordered row by row (previous row, same row, next row),
    each row going left / centre / right, and are returned as a float
    tensor of shape (9,).
    """
    r = torch.arange(-1, 2.0, device=device)
    return torch.cat([r - num_spixels_width, r, r + num_spixels_width], 0)


def calc_init_centroid(images, num_spixels_width, num_spixels_height):
    """
    calculate initial superpixels

    Args:
        images: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels_width: int
            number of superpixels along the image width
        num_spixels_height: int
            number of superpixels along the image height

    Return:
        centroids: torch.Tensor
            A Tensor of shape (B, C, num_spixels_height * num_spixels_width)
            holding the average feature of every initial grid cell
        init_label_map: torch.Tensor
            A Tensor of shape (B, H * W) with the initial superpixel label of
            every pixel; labels are stored with the same dtype as centroids
    """
    batchsize, channels, height, width = images.shape
    device = images.device

    centroids = torch.nn.functional.adaptive_avg_pool2d(
        images, (num_spixels_height, num_spixels_width))

    # The label map is pure bookkeeping, so it is built without autograd.
    with torch.no_grad():
        num_spixels = num_spixels_width * num_spixels_height
        labels = torch.arange(num_spixels, device=device)
        labels = labels.reshape(1, 1, *centroids.shape[-2:]).type_as(centroids)
        # Nearest-neighbour upsampling paints every pixel with the label of
        # the grid cell it falls into.
        init_label_map = torch.nn.functional.interpolate(
            labels, size=(height, width), mode="nearest")
        init_label_map = init_label_map.repeat(batchsize, 1, 1, 1)

    init_label_map = init_label_map.reshape(batchsize, -1)
    # Reshaped outside of no_grad so gradients can still reach `images`.
    centroids = centroids.reshape(batchsize, channels, -1)

    return centroids, init_label_map


@torch.no_grad()
def get_abs_indices(init_label_map, num_spixels_width):
    """
    build (batch, superpixel, pixel) index triples for the 9 neighbours

    Args:
        init_label_map: torch.Tensor
            A Tensor of shape (B, N) with the initial label of every pixel
        num_spixels_width: int
            number of superpixels along the image width

    Return:
        abs_indices: torch.Tensor
            A long Tensor of shape (3, B * 9 * N); the rows are the batch
            index, the absolute superpixel index and the pixel index.
            Superpixel indices are NOT clipped here, so they may be negative
            or exceed the number of superpixels; callers mask them out.
    """
    b, n_pixel = init_label_map.shape
    device = init_label_map.device
    relative_spix_indices = _relative_spix_indices(num_spixels_width, device)

    # All three index tensors are laid out as (B, 9, N) before flattening.
    abs_pix_indices = torch.arange(n_pixel, device=device)[None, None]
    abs_pix_indices = abs_pix_indices.repeat(b, 9, 1).reshape(-1).long()

    abs_spix_indices = init_label_map[:, None] + relative_spix_indices[None, :, None]
    abs_spix_indices = abs_spix_indices.reshape(-1).long()

    abs_batch_indices = torch.arange(b, device=device)[:, None, None]
    abs_batch_indices = abs_batch_indices.repeat(1, 9, n_pixel).reshape(-1).long()

    return torch.stack([abs_batch_indices, abs_spix_indices, abs_pix_indices], 0)


@torch.no_grad()
def get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width):
    """
    convert soft neighbour affinities into absolute superpixel labels

    Args:
        affinity_matrix: torch.Tensor
            A Tensor whose dim 1 enumerates the 9 neighbouring superpixels
        init_label_map: torch.Tensor
            A Tensor with the initial label of every pixel
        num_spixels_width: int
            number of superpixels along the image width

    Return:
        label: torch.Tensor
            A long Tensor with the absolute label of the neighbour that has
            the largest affinity for every pixel
    """
    # Index (0..8) of the best neighbour for every pixel.
    relative_label = affinity_matrix.max(1)[1]
    relative_spix_indices = _relative_spix_indices(
        num_spixels_width, affinity_matrix.device)
    label = init_label_map + relative_spix_indices[relative_label]
    return label.long()


@torch.no_grad()
def sparse_ssn_iter(pixel_features, num_spixels, n_iter):
    """
    computing assignment iterations with sparse matrix
    detailed process is in Algorithm 1, line 2 - 6

    NOTE: this function does NOT guarantee the backward computation.

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            A number of superpixels
        n_iter: int
            A number of iterations, must be at least 1

    Return:
        sparse_abs_affinity: torch.Tensor
            A sparse COO Tensor indexed by (batch, superpixel, pixel)
        hard_labels: torch.Tensor
            A long Tensor of shape (B, H * W) with the hard assignment
        spixel_features: torch.Tensor
            The updated superpixel features (presumably (B, C, num_spixels),
            mirroring ssn_iter; depends on naive_sparse_bmm)

    Raises:
        ValueError: if n_iter is smaller than 1
    """
    _require_positive_n_iter(n_iter)

    height, width = pixel_features.shape[-2:]
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    spixel_features, init_label_map = \
        calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height)
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)
    permuted_pixel_features = pixel_features.permute(0, 2, 1)

    # The neighbour indices never change between iterations, so the validity
    # mask and the selected indices are computed once, outside the loop.
    # NOTE(review): the upper bound is the requested num_spixels, while the
    # grid holds num_spixels_width * num_spixels_height cells, which can be
    # smaller after the int() truncation above; confirm this is intended.
    mask = (abs_indices[1] >= 0) & (abs_indices[1] < num_spixels)
    valid_indices = abs_indices[:, mask]

    for _ in range(n_iter):
        dist_matrix = PairwiseDistFunction.apply(
            pixel_features, spixel_features, init_label_map,
            num_spixels_width, num_spixels_height)

        affinity_matrix = (-dist_matrix).softmax(1)
        reshaped_affinity_matrix = affinity_matrix.reshape(-1)

        sparse_abs_affinity = torch.sparse_coo_tensor(
            valid_indices, reshaped_affinity_matrix[mask])

        # Affinity-weighted mean of the pixel features; the epsilon guards
        # against superpixels that received no affinity mass.
        spixel_features = naive_sparse_bmm(sparse_abs_affinity, permuted_pixel_features) \
            / (torch.sparse.sum(sparse_abs_affinity, 2).to_dense()[..., None] + 1e-16)

        spixel_features = spixel_features.permute(0, 2, 1)

    hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width)

    return sparse_abs_affinity, hard_labels, spixel_features


def ssn_iter(pixel_features, num_spixels, n_iter):
    """
    computing assignment iterations
    detailed process is in Algorithm 1, line 2 - 6

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, H, W)
        num_spixels: int
            A number of superpixels
        n_iter: int
            A number of iterations, must be at least 1

    Return:
        abs_affinity: torch.Tensor
            A dense Tensor indexed by (batch, superpixel, pixel) with the
            soft assignment of the last iteration
        hard_labels: torch.Tensor
            A long Tensor of shape (B, H * W) with the hard assignment
        spixel_features: torch.Tensor
            A Tensor of shape (B, C, number of superpixels)

    Raises:
        ValueError: if n_iter is smaller than 1
    """
    _require_positive_n_iter(n_iter)

    height, width = pixel_features.shape[-2:]
    num_spixels_width = int(math.sqrt(num_spixels * width / height))
    num_spixels_height = int(math.sqrt(num_spixels * height / width))

    # Example shapes in the comments below are for B=10, C=202, a 200x200
    # input and 64 superpixels.
    # spixel_features: (B, C, S), e.g. 10 * 202 * 64
    # init_label_map: (B, H*W), e.g. 10 * 40000
    spixel_features, init_label_map = \
        calc_init_centroid(pixel_features, num_spixels_width, num_spixels_height)
    # get indices of the 9 neighbors
    abs_indices = get_abs_indices(init_label_map, num_spixels_width)

    # (B, C, H*W), e.g. 10 * 202 * 40000
    pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)
    # (B, H*W, C), e.g. 10 * 40000 * 202
    permuted_pixel_features = pixel_features.permute(0, 2, 1).contiguous()

    # The neighbour indices never change between iterations, so the validity
    # mask and the selected indices are computed once, outside the loop.
    # NOTE(review): the upper bound is the requested num_spixels, while the
    # grid holds num_spixels_width * num_spixels_height cells, which can be
    # smaller after the int() truncation above; confirm this is intended.
    mask = (abs_indices[1] >= 0) & (abs_indices[1] < num_spixels)
    valid_indices = abs_indices[:, mask]

    for _ in range(n_iter):
        # (B, 9, H*W), e.g. 10 * 9 * 40000
        dist_matrix = PairwiseDistFunction.apply(
            pixel_features, spixel_features, init_label_map,
            num_spixels_width, num_spixels_height)

        affinity_matrix = (-dist_matrix).softmax(1)
        reshaped_affinity_matrix = affinity_matrix.reshape(-1)

        # (B, S, H*W), e.g. 10 * 64 * 40000
        sparse_abs_affinity = torch.sparse_coo_tensor(
            valid_indices, reshaped_affinity_matrix[mask])

        abs_affinity = sparse_abs_affinity.to_dense().contiguous()
        # Affinity-weighted mean of the pixel features; the epsilon guards
        # against superpixels that received no affinity mass.
        spixel_features = torch.bmm(abs_affinity, permuted_pixel_features) \
            / (abs_affinity.sum(2, keepdim=True) + 1e-16)

        # back to (B, C, S) for the next distance computation
        spixel_features = spixel_features.permute(0, 2, 1).contiguous()

    hard_labels = get_hard_abs_labels(affinity_matrix, init_label_map, num_spixels_width)

    return abs_affinity, hard_labels, spixel_features


def ssn_iter2(pixel_features, num_spixels, n_iter, init_spixel_features, temp=1):
    """
    computing assignment iterations for second layer

    Args:
        pixel_features: torch.Tensor
            A Tensor of shape (B, C, N)
        num_spixels: int
            A number of superpixels (not used by the computation; the count
            is taken from init_spixel_features)
        n_iter: int
            A number of iterations, must be at least 1
        init_spixel_features: torch.Tensor
            A Tensor of shape (B, C, num_spixels)
        temp: float
            factor applied to the negated distances before the softmax

    Return:
        aff: torch.Tensor
            A Tensor of shape (B, num_spixels, N) with the soft assignment
        hard_labels: torch.Tensor
            A Tensor of shape (B, N) with the index of the best superpixel
        spixel_features: torch.Tensor
            A Tensor of shape (B, num_spixels, C); note that, unlike the
            input, the channel dimension is last

    Raises:
        ValueError: if n_iter is smaller than 1
    """
    _require_positive_n_iter(n_iter)

    spixel_features = init_spixel_features.permute(0, 2, 1)
    pixel_features = pixel_features.permute(0, 2, 1)

    for _ in range(n_iter):
        # compute distance to all spixel_features
        dist = torch.cdist(pixel_features, spixel_features)  # B, N, num_spixels
        aff = (-dist * temp).softmax(-1).permute(0, 2, 1)  # B, num_spixels, N
        # compute new superpixels centers
        spixel_features = torch.bmm(aff, pixel_features) \
            / (aff.sum(2, keepdim=True) + 1e-6)  # B, num_spixels, C

    hard_labels = torch.argmax(aff, dim=1)

    return aff, hard_labels, spixel_features