""" Network utils
- poolfeat: aggregate superpixel features from pixel features
- upfeat: reconstruction pixel features from superpixel features
- quantize: quantization features given a codebook
"""
import torch
def poolfeat(input, prob, avg=True):
    """ A function to aggregate superpixel features from pixel features
    Args:
        input (tensor): input feature tensor.
        prob (tensor): one-hot superpixel segmentation.
        avg (bool, optional): average (True) or sum (False) the pixel features
            within each superpixel. Defaults to True.
    Returns:
        cluster_feat (tensor): the superpixel features
    Shape:
        input: (B, C, H, W)
        prob: (B, N, H, W)
        cluster_feat: (B, N, C)
    """
B, C, H, W = input.shape
B, N, H, W = prob.shape
    prob_flat = prob.view(B, N, -1)    # (B, N, H*W)
    input_flat = input.view(B, C, -1)  # (B, C, H*W)
    # accumulate pixel features into their assigned superpixels
    cluster_feat = torch.matmul(prob_flat, input_flat.permute(0, 2, 1))  # (B, N, C)
    if avg:
        # normalize by superpixel size to get the mean feature
        cluster_sum = torch.sum(prob_flat, dim=-1).view(B, N, 1)
        cluster_feat = cluster_feat / (cluster_sum + 1e-8)
return cluster_feat
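
# Minimal usage sketch for poolfeat. The shapes and the `_demo_poolfeat`
# helper below are illustrative assumptions, not part of the original code.
def _demo_poolfeat():
    B, C, H, W, N = 2, 16, 8, 8, 4
    feat = torch.randn(B, C, H, W)
    # hard assignment: each pixel belongs to exactly one of N superpixels
    idx = torch.randint(0, N, (B, 1, H, W))
    prob = torch.zeros(B, N, H, W).scatter_(1, idx, 1.0)
    sp_feat = poolfeat(feat, prob)  # averaged superpixel features
    assert sp_feat.shape == (B, N, C)
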
def upfeat(input, prob):
""" A function to compute pixel features from superpixel features
Args:
input (tensor): superpixel feature tensor.
prob (tensor): one-hot superpixel segmentation.
Returns:
reconstr_feat (tensor): the pixel features.
Shape:
input: (B, N, C)
prob: (B, N, H, W)
reconstr_feat: (B, C, H, W)
"""
    B, N, H, W = prob.shape
    prob_flat = prob.view(B, N, -1)  # (B, N, H*W)
    # distribute each superpixel feature back to its member pixels
    reconstr_feat = torch.matmul(prob_flat.permute(0, 2, 1), input)  # (B, H*W, C)
    reconstr_feat = reconstr_feat.view(B, H, W, -1).permute(0, 3, 1, 2)
    return reconstr_feat
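
# Minimal round-trip sketch combining poolfeat and upfeat. The shapes and the
# `_demo_upfeat` helper are illustrative assumptions, not original code.
def _demo_upfeat():
    B, C, H, W, N = 2, 16, 8, 8, 4
    feat = torch.randn(B, C, H, W)
    idx = torch.randint(0, N, (B, 1, H, W))
    prob = torch.zeros(B, N, H, W).scatter_(1, idx, 1.0)
    sp_feat = poolfeat(feat, prob)  # (B, N, C) superpixel features
    # every pixel now carries the mean feature of its superpixel
    recon = upfeat(sp_feat, prob)   # (B, C, H, W)
    assert recon.shape == feat.shape
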
def quantize(z, embedding, beta=0.25):
    """
    Takes the output of the encoder network z and maps each feature to a
    discrete one-hot vector indexing the closest embedding vector e_j
    Args:
        z (tensor): features from the encoder network
        embedding (tensor): codebook
        beta (scalar, optional): weight on the second loss term (see Returns).
            Defaults to 0.25.
    Returns:
        z_q: quantized features
        loss: mean((sg[z_q] - z)^2) + beta * mean((z_q - sg[z])^2), where sg
            denotes the stop-gradient (detach) operator
        (min_encodings, min_encoding_indices, d): the one-hot assignments, the
            assignment indices, and the squared distances to the codebook
    Shape:
        z: (B, N, C)
        embedding: (B, K, C)
        z_q: (B, N, C)
        min_encodings: (B, N, K)
        min_encoding_indices: (B, N, 1)
        d: (B, N, K)
    Note:
        Adapted from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
    """
    # z: (B, N, C); embedding: (B, K, C)
    if embedding.shape[0] == 1:
        # a single codebook shared across the batch: broadcast over B
        # d_ij = ||z_i||^2 + ||e_j||^2 - 2 * z_i . e_j, shape (B, N, K)
        d = (torch.sum(z ** 2, dim=2, keepdim=True)
             + torch.sum(embedding ** 2, dim=2)
             - 2 * torch.matmul(z, embedding.transpose(1, 2)))
    else:
        # per-sample codebooks: compute distances one batch element at a time
        ds = []
        for i in range(embedding.shape[0]):
            z_i = z[i:i + 1]
            embedding_i = embedding[i:i + 1]
            ds.append(torch.sum(z_i ** 2, dim=2, keepdim=True)
                      + torch.sum(embedding_i ** 2, dim=2)
                      - 2 * torch.matmul(z_i, embedding_i.transpose(1, 2)))
        d = torch.cat(ds)
    # find closest encodings
    min_encoding_indices = torch.argmin(d, dim=2).unsqueeze(2)  # (B, N, 1)
    n_e = embedding.shape[1]  # number of codebook entries K
    # one-hot assignment over the codebook, shape (B, N, K)
    min_encodings = torch.zeros(z.shape[0], z.shape[1], n_e).to(z)
    min_encodings.scatter_(2, min_encoding_indices, 1)
    # get quantized latent vectors by selecting codebook rows, shape (B, N, C)
    # TODO: the scatter + matmul could be replaced by directly indexing the
    # codebook with min_encoding_indices
    z_q = torch.matmul(min_encodings, embedding).view(z.shape)
    # compute loss for embedding
    loss = torch.mean((z_q.detach() - z) ** 2) + beta * torch.mean((z_q - z.detach()) ** 2)
    # preserve gradients: straight-through estimator copies gradients from z_q to z
    z_q = z + (z_q - z).detach()
return z_q, loss, (min_encodings, min_encoding_indices, d)
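
# Minimal usage sketch for quantize. The random codebook, shapes, and the
# `_demo_quantize` helper are illustrative assumptions, not original code.
def _demo_quantize():
    B, N, C, K = 2, 256, 32, 512
    z = torch.randn(B, N, C)         # encoder features
    codebook = torch.randn(1, K, C)  # a single codebook shared across the batch
    z_q, loss, (one_hot, indices, d) = quantize(z, codebook)
    assert z_q.shape == z.shape and loss.dim() == 0
    assert one_hot.shape == (B, N, K) and indices.shape == (B, N, 1)

if __name__ == "__main__":
    # exercise the sketches above
    _demo_poolfeat()
    _demo_upfeat()
    _demo_quantize()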