""" Network utils
 - poolfeat: aggregate superpixel features from pixel features
 - upfeat: reconstruction pixel features from superpixel features
 - quantize: quantization features given a codebook
"""
import torch

def poolfeat(input, prob, avg=True):
    """ A function to aggregate superpixel features from pixel features

    Args:
      input (tensor): input feature tensor.
      prob (tensor): one-hot superpixel segmentation.
      avg (bool, optional): average (True) or sum (False) the pixel features within each superpixel.

    Returns:
      cluster_feat (tensor): the superpixel features.

    Shape:
      input: (B, C, H, W)
      prob: (B, N, H, W)
      cluster_feat: (B, N, C)
    """
    B, C, _, _ = input.shape
    N = prob.shape[1]
    prob_flat = prob.view(B, N, -1)    # (B, N, H*W)
    input_flat = input.view(B, C, -1)  # (B, C, H*W)
    # Sum the pixel features assigned to each superpixel:
    # (B, N, H*W) x (B, H*W, C) -> (B, N, C).
    cluster_feat = torch.matmul(prob_flat, input_flat.permute(0, 2, 1))
    if avg:
        # Normalize by each superpixel's pixel count; the epsilon avoids
        # division by zero for empty superpixels.
        cluster_sum = torch.sum(prob_flat, dim=-1).view(B, N, 1)
        cluster_feat = cluster_feat / (cluster_sum + 1e-8)
    return cluster_feat
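
# A minimal usage sketch for poolfeat (the sizes below are illustrative
# assumptions, not values required by this module):
#
#   feat = torch.randn(2, 16, 8, 8)               # (B, C, H, W) pixel features
#   hard = torch.randint(0, 50, (2, 8, 8))        # per-pixel superpixel ids
#   prob = torch.nn.functional.one_hot(hard, 50).permute(0, 3, 1, 2).float()
#   sp_feat = poolfeat(feat, prob)                # (2, 50, 16) per-superpixel means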

def upfeat(input, prob):
    """ A function to reconstruct pixel features from superpixel features

    Args:
      input (tensor): superpixel feature tensor.
      prob (tensor): one-hot superpixel segmentation.

    Returns:
      reconstr_feat (tensor): the pixel features.

    Shape:
      input: (B, N, C)
      prob: (B, N, H, W)
      reconstr_feat: (B, C, H, W)
    """
    B, N, H, W = prob.shape
    prob_flat = prob.view(B, N, -1)
    # Scatter each superpixel's feature back to its pixels:
    # (B, H*W, N) x (B, N, C) -> (B, H*W, C).
    reconstr_feat = torch.matmul(prob_flat.permute(0, 2, 1), input)
    reconstr_feat = reconstr_feat.view(B, H, W, -1).permute(0, 3, 1, 2)

    return reconstr_feat
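
# upfeat composes with poolfeat: for a hard (one-hot) assignment,
# upfeat(poolfeat(x, prob), prob) yields features that are constant within
# each superpixel (each superpixel's mean broadcast back to its pixels).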

def quantize(z, embedding, beta=0.25):
    """ Map each feature of the encoder output z to the index of the
    closest embedding vector e_j and return the quantized features.

    Args:
      z (tensor): features from the encoder network.
      embedding (tensor): codebook; a leading batch dimension of 1 is
        shared across the batch, otherwise one codebook per sample.
      beta (scalar, optional): commit loss weight.

    Returns:
      z_q (tensor): quantized features.
      loss (tensor): vq loss + commit loss * beta.
      (min_encodings, min_encoding_indices, d): one-hot quantization
        assignments, assignment indices, and the pairwise squared
        distances between features and codebook entries.

    Shape:
      z: (B, N, C)
      embedding: (B, K, C) or (1, K, C)
      z_q: (B, N, C)
      min_encodings: (B, N, K)
      min_encoding_indices: (B, N, 1)
      d: (B, N, K)

    Note:
      Adapted from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
    """

    # Pairwise squared distances between features and codebook entries:
    # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, shape (B, N, K).
    # Broadcasting handles both a shared codebook (1, K, C) and
    # per-sample codebooks (B, K, C), replacing the per-sample loop.
    d = (torch.sum(z ** 2, dim=2, keepdim=True)
         + torch.sum(embedding ** 2, dim=2).unsqueeze(1)
         - 2 * torch.matmul(z, embedding.transpose(1, 2)))

    # Find the closest codebook entry for each feature.
    min_encoding_indices = torch.argmin(d, dim=2).unsqueeze(2)  # (B, N, 1)

    # One-hot encode the assignments.
    n_e = embedding.shape[1]
    min_encodings = torch.zeros(z.shape[0], z.shape[1], n_e).to(z)
    min_encodings.scatter_(2, min_encoding_indices, 1)

    # Look up the quantized latent vectors via the one-hot assignments.
    # (A direct gather of codebook rows at min_encoding_indices would be
    # an equivalent alternative to this matmul.)
    z_q = torch.matmul(min_encodings, embedding).view(z.shape)

    # compute loss for embedding
    loss = torch.mean((z_q.detach() - z) ** 2) + beta * torch.mean((z_q - z.detach()) ** 2)

    # Preserve gradients with the straight-through estimator: the forward
    # pass returns z_q, while the backward pass routes gradients through z.
    z_q = z + (z_q - z).detach()

    return z_q, loss, (min_encodings, min_encoding_indices, d)
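
if __name__ == "__main__":
    # Minimal smoke test; a sketch only, with sizes that are illustrative
    # assumptions rather than values required by this module.
    B, C, H, W, N, K = 2, 16, 8, 8, 50, 32
    feat = torch.randn(B, C, H, W)
    hard = torch.randint(0, N, (B, H, W))
    prob = torch.nn.functional.one_hot(hard, N).permute(0, 3, 1, 2).float()

    sp_feat = poolfeat(feat, prob)        # (B, N, C) per-superpixel means
    pix_feat = upfeat(sp_feat, prob)      # (B, C, H, W) piecewise-constant features
    assert sp_feat.shape == (B, N, C)
    assert pix_feat.shape == feat.shape

    codebook = torch.randn(1, K, C)       # a single codebook shared across the batch
    z_q, loss, (enc, idx, d) = quantize(sp_feat, codebook)
    assert z_q.shape == sp_feat.shape
    assert enc.shape == (B, N, K) and idx.shape == (B, N, 1) and d.shape == (B, N, K)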