ver217 committed on
Commit
0b09127
1 Parent(s): 387aa7a

[feature] support decode from embeddings

Browse files
Files changed (1) hide show
  1. modeling_vqvae.py +17 -0
modeling_vqvae.py CHANGED
@@ -78,6 +78,11 @@ class VQVAE(PreTrainedModel):
78
  h = self.post_vq_conv(shift_dim(h, -1, 1))
79
  return self.decoder(h)
80
 
 
 
 
 
 
81
  def forward(self, x):
82
  z = self.pre_vq_conv(self.encoder(x))
83
  vq_output = self.codebook(z)
@@ -159,6 +164,18 @@ class Codebook(nn.Module):
159
  self.z_avg.data.copy_(_k_rand)
160
  self.N.data.copy_(torch.ones(self.n_codes))
161
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def forward(self, z):
163
  # z: [b, c, t, h, w]
164
  if self._need_init and self.training:
 
78
  h = self.post_vq_conv(shift_dim(h, -1, 1))
79
  return self.decoder(h)
80
 
81
def decode_from_embeddings(self, embeddings):
    """Decode continuous embeddings back to pixel space via the codebook.

    Each position in *embeddings* is snapped to its nearest codebook entry
    (by index lookup) and the resulting discrete codes are run through the
    decoder.

    Args:
        embeddings: tensor shaped [b, c, t, h, w] living in codebook
            (pre-quantization) space.

    Returns:
        The decoder output for the nearest-code indices.
    """
    nearest_codes = self.codebook.search_indices(embeddings)
    return self.decode(nearest_codes)
86
  def forward(self, x):
87
  z = self.pre_vq_conv(self.encoder(x))
88
  vq_output = self.codebook(z)
 
164
  self.z_avg.data.copy_(_k_rand)
165
  self.N.data.copy_(torch.ones(self.n_codes))
166
 
167
def search_indices(self, z):
    """Return the index of the nearest codebook vector for every position.

    Args:
        z: embedding tensor of shape [b, c, t, h, w] (channels in dim 1).

    Returns:
        Integer tensor of shape [b, t, h, w] with nearest-code indices.
    """
    # Move channels last and collapse every leading dim: [b*t*h*w, c].
    flat = shift_dim(z, 1, -1).flatten(end_dim=-2)
    codes_t = self.embeddings.t()  # [c, n_codes]

    # Squared Euclidean distances via ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2;
    # the [N, 1] and [1, n_codes] terms broadcast against the [N, n_codes] cross term.
    sq_inputs = (flat ** 2).sum(dim=1, keepdim=True)
    sq_codes = (codes_t ** 2).sum(dim=0, keepdim=True)
    distances = sq_inputs - 2 * flat @ codes_t + sq_codes

    nearest = torch.argmin(distances, dim=1)
    # Restore the spatial layout (channel dim is gone): [b, t, h, w].
    return nearest.view(z.shape[0], *z.shape[2:])
179
  def forward(self, z):
180
  # z: [b, c, t, h, w]
181
  if self._need_init and self.training: