""" Network utils - poolfeat: aggregate superpixel features from pixel features - upfeat: reconstruction pixel features from superpixel features - quantize: quantization features given a codebook """ import torch def poolfeat(input, prob, avg = True): """ A function to aggregate superpixel features from pixel features Args: input (tensor): input feature tensor. prob (tensor): one-hot superpixel segmentation. avg (bool, optional): average or sum the pixel features to get superpixel features Returns: cluster_feat (tensor): the superpixel features Shape: input: (B, C, H, W) prob: (B, N, H, W) cluster_feat: (B, N, C) """ B, C, H, W = input.shape B, N, H, W = prob.shape prob_flat = prob.view(B, N, -1) input_flat = input.view(B, C, -1) cluster_feat = torch.matmul(prob_flat, input_flat.permute(0, 2, 1)) if avg: cluster_sum = torch.sum(prob_flat, dim = -1).view(B, N , 1) cluster_feat = cluster_feat / (cluster_sum + 1e-8) return cluster_feat def upfeat(input, prob): """ A function to compute pixel features from superpixel features Args: input (tensor): superpixel feature tensor. prob (tensor): one-hot superpixel segmentation. Returns: reconstr_feat (tensor): the pixel features. Shape: input: (B, N, C) prob: (B, N, H, W) reconstr_feat: (B, C, H, W) """ B, N, H, W = prob.shape prob_flat = prob.view(B, N, -1) reconstr_feat = torch.matmul(prob_flat.permute(0, 2, 1), input) reconstr_feat = reconstr_feat.view(B, H, W, -1).permute(0, 3, 1, 2) return reconstr_feat def quantize(z, embedding, beta = 0.25): """ Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the closest embedding vector e_j Args: z (tensor): features from the encoder network embedding (tensor): codebook beta (scalar, optional): commit loss weight Returns: z_q: quantized features loss: vq loss + commit loss * beta min_encodings: quantization assignment one hot vector min_encoding_indices: quantization assignment Shape: z: B, N, C embedding: B, K, C z_q: B, N, C min_encodings: B, N, K min_encoding_indices: B, N, 1 Note: Adapted from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py """ # B, 256, 32 if embedding.shape[0] == 1: d = torch.sum(z ** 2, dim=2, keepdim=True) + torch.sum(embedding**2, dim=2) - 2 * torch.matmul(z, embedding.transpose(1, 2)) else: ds = [] for i in range(embedding.shape[0]): z_i = z[i:i+1] embedding_i = embedding[i:i+1] ds.append(torch.sum(z_i ** 2, dim=2, keepdim=True) + torch.sum(embedding_i**2, dim=2) - 2 * torch.matmul(z_i, embedding_i.transpose(1, 2))) d = torch.cat(ds) ## could possible replace this here # #\start... # find closest encodings min_encoding_indices = torch.argmin(d, dim=2).unsqueeze(2) # B, 256, 1 #min_encodings = torch.zeros( # min_encoding_indices.shape[0], self.n_e).to(z) #min_encodings.scatter_(1, min_encoding_indices, 1) n_e = embedding.shape[1] # 32 min_encodings = torch.zeros(z.shape[0], z.shape[1], n_e).to(z) min_encodings.scatter_(2, min_encoding_indices, 1) # dtype min encodings: torch.float32 # min_encodings shape: torch.Size([2048, 512]) # min_encoding_indices.shape: torch.Size([2048, 1]) # get quantized latent vectors z_q = torch.matmul(min_encodings, embedding).view(z.shape) #.........\end # with: # .........\start #min_encoding_indices = torch.argmin(d, dim=1) #z_q = self.embedding(min_encoding_indices) # ......\end......... 
    # compute loss for embedding
    loss = torch.mean((z_q.detach() - z) ** 2) + beta * torch.mean((z_q - z.detach()) ** 2)

    # preserve gradients (straight-through estimator)
    z_q = z + (z_q - z).detach()

    return z_q, loss, (min_encodings, min_encoding_indices, d)
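

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module) showing how the three
# helpers fit together. The tensor sizes below (B=2, C=16, H=W=8, N=4
# superpixels, K=32 codes) are illustrative assumptions only. It builds a hard
# one-hot superpixel assignment, pools pixel features into superpixel features,
# reconstructs the dense map, and quantizes the superpixel features against a
# randomly initialised shared codebook.
if __name__ == "__main__":
    B, C, H, W = 2, 16, 8, 8
    N, K = 4, 32

    feat = torch.randn(B, C, H, W)

    # hard assignment of every pixel to one of N superpixels -> one-hot (B, N, H, W)
    labels = torch.randint(0, N, (B, H, W))
    prob = torch.zeros(B, N, H, W).scatter_(1, labels.unsqueeze(1), 1.0)

    sp_feat = poolfeat(feat, prob)    # (B, N, C) superpixel features
    pix_feat = upfeat(sp_feat, prob)  # (B, C, H, W) reconstructed pixel features

    codebook = torch.randn(1, K, C)   # shared codebook -> broadcast branch of quantize
    z_q, vq_loss, (one_hot, indices, d) = quantize(sp_feat, codebook)

    print(sp_feat.shape, pix_feat.shape, z_q.shape, vq_loss.item())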