| import torch |
| from typing import List, Tuple |
| from torch.nn import functional as F |
| from torch import distributed as tdist, nn as nn |
|
|
| from .quantizer import VectorQuantizer |
|
|
def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
    """Entropy regularization for vector quantization.

    Builds soft assignment probabilities from negative squared Euclidean
    distances (scaled by ``inv_entropy_tau``) and returns

        mean per-sample entropy  -  entropy of the mean assignment,

    i.e. low when each sample is confidently assigned (first term) while
    the codebook as a whole is used uniformly (second term).

    Args:
        latent_embed: (N, D) latent vectors.
        codebook_embed: (K, D) codebook vectors.
        inv_entropy_tau: inverse temperature applied to the distances.

    Returns:
        Scalar tensor with the entropy loss.
    """
    # ||z - e||^2 expanded as ||z||^2 + ||e||^2 - 2 * z.e
    dist_sq = (
        latent_embed.square().sum(dim=1, keepdim=True)
        + codebook_embed.square().sum(dim=1)
        - 2.0 * (latent_embed @ codebook_embed.T)
    )
    logits = (-inv_entropy_tau) * dist_sq.float()

    log_prob = logits.log_softmax(dim=-1)
    prob = log_prob.exp()
    # Encourage confident (low-entropy) per-sample assignments.
    sample_entropy = (prob * log_prob).sum(dim=-1).neg().mean()

    # Encourage high entropy of the batch-averaged assignment (codebook usage).
    mean_prob = prob.mean(dim=0)
    usage_entropy = -(mean_prob * torch.log(mean_prob + 1e-7)).sum()

    return sample_entropy - usage_entropy
|
|
|
|
class NormalizedEmbedding(nn.Embedding):
    """Embedding table whose rows are L2-normalized at every lookup.

    The stored ``weight`` is unconstrained; normalization happens on the
    fly, so gradients flow through the normalization.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

    def forward(self, idx):
        """Look up ``idx`` against the unit-norm version of the weight table."""
        unit_weight = F.normalize(self.weight, dim=1)
        return F.embedding(
            idx, unit_weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse,
        )

    def get_norm_weight(self):
        """Return the row-wise L2-normalized weight matrix."""
        return F.normalize(self.weight, dim=1)
|
|
|
|
class ResConv(nn.Conv2d):
    """Conv2d applied as a residual blend: out = (1 - r) * x + r * conv(x).

    A negative ``quant_resi`` selects a 3x3 kernel (with padding 1),
    otherwise a 1x1 kernel is used; the blend ratio is ``abs(quant_resi)``.
    """

    def __init__(self, embed_dim, quant_resi):
        kernel = 3 if quant_resi < 0 else 1
        super().__init__(
            in_channels=embed_dim, out_channels=embed_dim,
            kernel_size=kernel, stride=1, padding=kernel // 2,
        )
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        """Blend the identity path with the convolved path by ``resi_ratio``."""
        conv_out = super().forward(h_BChw)
        return conv_out * self.resi_ratio + h_BChw * (1 - self.resi_ratio)
|
|
class VectorQuantizerMVQ(nn.Module):
    """Multi-codebook (product) vector quantizer.

    Splits the channel dimension into ``num_codebooks`` equal chunks and
    quantizes each chunk with an independent :class:`VectorQuantizer`.
    Quantized chunks are concatenated back along the channel dimension;
    per-codebook losses are averaged into a single result dict.
    """

    def __init__(
        self,
        codebook_size,
        token_size,
        commitment_cost=0.25,
        use_l2_norm=False,
        clustering_vq=False,
        num_codebooks=16,
    ):
        super().__init__()
        # Fail loudly instead of silently floor-dividing the sizes.
        assert codebook_size % num_codebooks == 0, \
            f"codebook_size ({codebook_size}) must be divisible by num_codebooks ({num_codebooks})"
        assert token_size % num_codebooks == 0, \
            f"token_size ({token_size}) must be divisible by num_codebooks ({num_codebooks})"
        self.num_codebooks = num_codebooks
        self.codebooks = nn.ModuleList(
            VectorQuantizer(
                codebook_size=codebook_size // num_codebooks,
                token_size=token_size // num_codebooks,
                commitment_cost=commitment_cost,
                use_l2_norm=use_l2_norm,
                clustering_vq=clustering_vq,
            )
            for _ in range(num_codebooks)
        )

    def init_vocab(self, eini: float):
        """Initialize every sub-codebook's vocabulary (delegates to each VQ)."""
        for codebook in self.codebooks:
            codebook.init_vocab(eini)

    def f_to_idx(self, features):
        """Quantize features to indices, one index tensor per sub-codebook.

        Args:
            features: tensor whose LAST dimension is the channel dimension
                (note: `forward` instead splits on dim 1 — channel-last vs
                channel-first callers).

        Returns:
            Indices stacked along dim 1, one slice per sub-codebook.
        """
        assert features.shape[-1] % self.num_codebooks == 0
        chunk_size = features.shape[-1] // self.num_codebooks
        chunks = features.split(chunk_size, dim=-1)
        indices = [cb.f_to_idx(chunk) for cb, chunk in zip(self.codebooks, chunks)]
        return torch.stack(indices, dim=1)

    def idx_to_f(self, indices):
        """Map per-codebook indices back to concatenated latent features.

        Args:
            indices: tensor of shape (B, num_codebooks, ...).
        """
        assert indices.shape[1] == self.num_codebooks
        latent_features = [
            cb.codebook(indices[:, i].flatten(start_dim=1))
            for i, cb in enumerate(self.codebooks)
        ]
        return torch.cat(latent_features, dim=-1)

    def get_codebook_entry(self, indices):
        """Get codebook entries for multi-codebook indices.

        Args:
            indices: Tensor of shape (N, num_codebooks) or (N, num_codebooks, H, W)

        Returns:
            z_quantized: Quantized features — (N, C) for 2-D input, or
            (N, C, H, W) for 4-D input.
        """
        # Both branches expect dim 1 to index the sub-codebooks (consistent
        # with idx_to_f).
        assert indices.shape[1] == self.num_codebooks
        if indices.dim() == 2:
            latent_features = [
                cb.get_codebook_entry(indices[:, i])
                for i, cb in enumerate(self.codebooks)
            ]
            return torch.cat(latent_features, dim=-1)
        elif indices.dim() == 4:
            batch_size, _, height, width = indices.shape
            latent_features = []
            for i, codebook in enumerate(self.codebooks):
                # Flatten spatial positions, look up, then restore (B, H, W, c).
                latent_feature = codebook.get_codebook_entry(indices[:, i].flatten())
                latent_features.append(latent_feature.view(batch_size, height, width, -1))
            # (B, H, W, C) -> (B, C, H, W)
            latent_features = torch.cat(latent_features, dim=-1)
            return latent_features.permute(0, 3, 1, 2).contiguous()
        else:
            raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")

    def forward(self, features):
        """Quantize ``features`` (channels on dim 1) with all sub-codebooks.

        Returns:
            z_quantized: concatenated quantized features, same dtype as input.
            result_dict: losses averaged over sub-codebooks, plus
                ``min_encoding_indices`` stacked along dim 1.
        """
        assert features.shape[1] % self.num_codebooks == 0
        chunk_size = features.shape[1] // self.num_codebooks
        chunks = features.split(chunk_size, dim=1)

        latent_features = []
        all_result_dicts = []
        for codebook, chunk in zip(self.codebooks, chunks):
            # Sub-quantizers run in float32; cast back to the caller's dtype.
            latent_feature, result_dict = codebook(chunk.float())
            latent_features.append(latent_feature.to(features.dtype))
            all_result_dicts.append(result_dict)

        z_quantized = torch.cat(latent_features, dim=1)

        def _mean_loss(key):
            # Average a scalar loss over the sub-codebooks.
            return sum(rd[key] for rd in all_result_dicts) / self.num_codebooks

        result_dict = dict(
            quantizer_loss=_mean_loss('quantizer_loss'),
            commitment_loss=_mean_loss('commitment_loss'),
            codebook_loss=_mean_loss('codebook_loss'),
            min_encoding_indices=torch.stack(
                [rd['min_encoding_indices'] for rd in all_result_dicts], dim=1
            ),
        )
        return z_quantized, result_dict