# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelVectorQuantizer(nn.Module):
    """Quantize feature vectors by picking one codebook entry per group.

    The input is projected to ``groups * num_vars`` logits. During training a
    hard (straight-through) Gumbel-softmax sample selects one entry per group;
    in eval mode the arg-max entry is selected. The selected entries of all
    groups are concatenated to form the quantized output vector.
    """

    def __init__(
        self,
        dim,
        num_vars,
        temp,
        groups,
        combine_groups,
        vq_dim,
        time_first,
        activation=nn.GELU(),
        weight_proj_depth=1,
        weight_proj_factor=1,
    ):
        """Vector quantization using gumbel softmax

        Args:
            dim: input dimension (channels)
            num_vars: number of quantized vectors per group
            temp: temperature for training. this should be a tuple of 3
                elements: (start, stop, decay factor). a string is also
                accepted and is parsed with ``ast.literal_eval``
            groups: number of groups for vector quantization
            combine_groups: whether to use the vectors for all groups
            vq_dim: dimensionality of the resulting quantized vector
            time_first: if true, expect input in BxTxC format, otherwise in
                BxCxT
            activation: what activation to use (should be a module). this is
                only used if weight_proj_depth is > 1
            weight_proj_depth: number of layers (with activation in between)
                to project input before computing logits
            weight_proj_factor: this is used only if weight_proj_depth is > 1.
                scales the inner dimensionality of projections by this factor

        Raises:
            AssertionError: if ``vq_dim`` is not divisible by ``groups`` or
                ``temp`` does not have exactly 3 elements.
        """
        super().__init__()

        self.groups = groups
        self.combine_groups = combine_groups
        self.input_dim = dim
        self.num_vars = num_vars
        self.time_first = time_first

        assert (
            vq_dim % groups == 0
        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"

        # Each group contributes an equal slice of the final quantized vector.
        var_dim = vq_dim // groups
        # With combine_groups a single set of num_vars entries is shared.
        num_groups = groups if not combine_groups else 1

        self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim))
        nn.init.uniform_(self.vars)

        if weight_proj_depth > 1:

            def block(input_dim, output_dim):
                return nn.Sequential(nn.Linear(input_dim, output_dim), activation)

            inner_dim = self.input_dim * weight_proj_factor
            self.weight_proj = nn.Sequential(
                *[
                    block(self.input_dim if i == 0 else inner_dim, inner_dim)
                    for i in range(weight_proj_depth - 1)
                ],
                nn.Linear(inner_dim, groups * num_vars),
            )
        else:
            self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
            nn.init.normal_(self.weight_proj.weight, mean=0, std=1)
            nn.init.zeros_(self.weight_proj.bias)

        if isinstance(temp, str):
            import ast

            temp = ast.literal_eval(temp)
        assert len(temp) == 3, f"{temp}, {len(temp)}"

        self.max_temp, self.min_temp, self.temp_decay = temp
        self.curr_temp = self.max_temp
        # Lazily built by get_codebook_indices().
        self.codebook_indices = None

    def set_num_updates(self, num_updates):
        """Decay the temperature exponentially, never going below min_temp."""
        self.curr_temp = max(
            self.max_temp * self.temp_decay ** num_updates, self.min_temp
        )

    def _build_codebook_indices(self, device):
        """Enumerate every combination of per-group entries as row indices."""
        from itertools import product

        combos = list(product(range(self.num_vars), repeat=self.groups))
        indices = torch.tensor(combos, dtype=torch.long, device=device).flatten()

        if not self.combine_groups:
            # Group g owns rows [g * num_vars, (g + 1) * num_vars) of self.vars,
            # so shift each column by its group's starting row.
            offsets = torch.arange(self.groups, device=device) * self.num_vars
            indices = indices.view(self.num_vars ** self.groups, -1) + offsets
            indices = indices.flatten()

        return indices

    def get_codebook_indices(self):
        """Return the flat index tensor for the full codebook.

        The tensor is cached on ``self.codebook_indices``. It is rebuilt when
        the cache is empty or lives on a different device than ``self.vars``
        (for example after the module was moved), so later ``index_select``
        calls never mix devices.
        """
        device = self.vars.device
        cached = self.codebook_indices
        if cached is None or cached.device != device:
            self.codebook_indices = self._build_codebook_indices(device)
        return self.codebook_indices

    def codebook(self):
        """Return all ``num_vars ** groups`` concatenated codebook vectors."""
        indices = self.get_codebook_indices()
        return (
            self.vars.squeeze(0)
            .index_select(0, indices)
            .view(self.num_vars ** self.groups, -1)
        )

    def sample_from_codebook(self, b, n):
        """Draw ``b * n`` random codebook vectors, returned as (b, n, -1).

        Raises:
            AssertionError: if ``n`` is not smaller than the codebook size.
        """
        indices = self.get_codebook_indices()
        indices = indices.view(-1, self.groups)
        cb_size = indices.size(0)
        assert (
            n < cb_size
        ), f"sample size {n} is greater than size of codebook {cb_size}"
        sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,))
        indices = indices[sample_idx]

        z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1)
        return z

    def to_codebook_index(self, indices):
        """Fold per-group indices (last dim) into one base-num_vars integer.

        The first group is the most significant digit.
        """
        res = indices.new_full(indices.shape[:-1], 0)
        for i in range(self.groups):
            exponent = self.groups - i - 1
            res += indices[..., i] * (self.num_vars ** exponent)
        return res

    def forward_idx(self, x):
        """Run forward and return ``(quantized, targets)``."""
        res = self.forward(x, produce_targets=True)
        return res["x"], res["targets"]

    def forward(self, x, produce_targets=False):
        """Quantize ``x``.

        Args:
            x: 3-d input, BxTxC if ``time_first`` else BxCxT.
            produce_targets: if true, also return the selected entry index
                of every group under ``"targets"`` (shape B x T x groups).

        Returns:
            dict with keys ``"num_vars"``, ``"code_perplexity"``,
            ``"prob_perplexity"``, ``"temp"``, ``"x"`` (quantized output in
            the same layout as the input) and optionally ``"targets"``.
        """
        result = {"num_vars": self.num_vars * self.groups}

        if not self.time_first:
            x = x.transpose(1, 2)

        bsz, tsz, fsz = x.shape
        x = x.reshape(-1, fsz)
        x = self.weight_proj(x)
        # One row of num_vars logits per (frame, group) pair.
        x = x.view(bsz * tsz * self.groups, -1)

        # One-hot of the arg-max entry; used for stats and for eval output.
        _, k = x.max(-1)
        hard_x = (
            x.new_zeros(*x.shape)
            .scatter_(-1, k.view(-1, 1), 1.0)
            .view(bsz * tsz, self.groups, -1)
        )
        hard_probs = torch.mean(hard_x.float(), dim=0)
        result["code_perplexity"] = torch.exp(
            -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
        ).sum()

        avg_probs = torch.softmax(
            x.view(bsz * tsz, self.groups, -1).float(), dim=-1
        ).mean(dim=0)
        result["prob_perplexity"] = torch.exp(
            -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
        ).sum()

        result["temp"] = self.curr_temp

        if self.training:
            x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x)
        else:
            x = hard_x

        x = x.view(bsz * tsz, -1)

        codebook_vars = self.vars
        if self.combine_groups:
            # Shared entries: tile them so every group sees the same vectors.
            codebook_vars = codebook_vars.repeat(1, self.groups, 1)

        if produce_targets:
            result["targets"] = (
                x.view(bsz * tsz * self.groups, -1)
                .argmax(dim=-1)
                .view(bsz, tsz, self.groups)
                .detach()
            )

        # Weight every entry by its (one-hot) selection and sum within a group.
        x = x.unsqueeze(-1) * codebook_vars
        x = x.view(bsz * tsz, self.groups, self.num_vars, -1)
        x = x.sum(-2)
        x = x.view(bsz, tsz, -1)

        if not self.time_first:
            x = x.transpose(1, 2)  # BTC -> BCT

        result["x"] = x

        return result