import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops.layers.torch import Rearrange


def normalize(in_channels):
    return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


def swish(x):
    return x * torch.sigmoid(x)


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, activation_fn="relu"):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if self.in_channels != self.out_channels:
            # 1x1 projection so the residual shortcut matches the output width
            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.activation_fn = activation_fn
        if activation_fn == "relu":
            self.actn = nn.ReLU()

    def _act(self, x):
        if self.activation_fn == "relu":
            return self.actn(x)
        elif self.activation_fn == "swish":
            return swish(x)
        return x

    def forward(self, x_in):
        x = self.conv1(self._act(self.norm1(x_in)))
        x = self.conv2(self._act(self.norm2(x)))
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)
        return x + x_in


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.filters = 128
        self.num_res_blocks = 2
        self.ch_mult = [1, 1, 2, 2, 4]
        self.in_ch_mult = (1,) + tuple(self.ch_mult)
        self.embedding_dim = 32
        self.conv_downsample = False
        self.conv1 = nn.Conv2d(3, self.filters, kernel_size=3, stride=1, padding=1, bias=False)

        blocks = []
        for i in range(len(self.ch_mult)):
            block_in_ch = self.filters * self.in_ch_mult[i]
            block_out_ch = self.filters * self.ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
                block_in_ch = block_out_ch
        # two extra res blocks at the bottleneck resolution
        for _ in range(self.num_res_blocks):
            blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
        self.norm1 = normalize(block_in_ch)
        self.conv2 = nn.Conv2d(block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0)
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        x = self.conv1(x)
        for i in range(len(self.ch_mult)):
            for j in range(self.num_res_blocks):
                x = self.blocks[i * self.num_res_blocks + j](x)
            if i < len(self.ch_mult) - 1:
                # 2x spatial downsample between stages
                x = F.avg_pool2d(x, (2, 2), (2, 2))
        # final res blocks at the lowest resolution
        x = self.blocks[-2](x)
        x = self.blocks[-1](x)
        x = self.norm1(x)
        x = swish(x)
        x = self.conv2(x)
        return x


class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size=8192, emb_dim=32, beta=None):
        super().__init__()
        self.codebook_size = codebook_size  # number of codebook entries
        self.emb_dim = emb_dim  # dimension of each codebook entry
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
        self.beta = 0.0 if beta is None else beta  # commitment-loss weight
        self.z_dim = emb_dim

    def forward(self, z):
        # flatten (b, c, h, w) -> (b*h*w, c) so each spatial position is one query
        b, c, h, w = z.size()
        flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
        codebook = self.embedding.weight
        with torch.no_grad():
            # nearest codebook entry (L2 distance) for every latent vector
            tokens = torch.cdist(flatten, codebook).argmin(dim=1)
        quantized = F.embedding(tokens, codebook).view(b, h, w, c).permute(0, 3, 1, 2)

        # VQ losses: pull the codebook toward the encoder output, and (weighted
        # by beta) the encoder output toward its assigned codebook entry
        codebook_loss = F.mse_loss(quantized, z.detach())
        commitment_loss = F.mse_loss(quantized.detach(), z)
        loss = codebook_loss + self.beta * commitment_loss

        # perplexity of the code-usage distribution (exp of its entropy)
        counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype)
        # dist.all_reduce(counts)  # uncomment to aggregate counts across ranks in DDP
        p = counts / counts.sum()
        perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))

        tokens = tokens.view(b, h, w)
        # straight-through estimator: values come from the codebook,
        # gradients flow back into z
        quantized = z + (quantized - z).detach()
        return quantized, tokens, loss, perplexity

    def get_codebook_feat(self, indices, shape=None):
        # indices: (batch * token_num,) -> one-hot -> codebook lookup
        # shape: (batch, height, width, channel) of the target feature map
        indices = indices.view(-1, 1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
        if shape is not None:
            # reshape back to (b, c, h, w)
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
        return z_q
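# A minimal shape walk-through of the encoder + quantizer (illustrative sketch,
# not part of the checkpointed model): with four 2x average-pool downsamples,
# a 224x224 image becomes a 14x14 grid of 32-dim latents, each mapped to one
# of 8192 codebook ids.
#
#   enc = Encoder()
#   vq = VectorQuantizer()
#   z = enc(torch.randn(1, 3, 224, 224))     # -> (1, 32, 14, 14)
#   quant, tokens, loss, perplexity = vq(z)  # tokens: (1, 14, 14)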
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.filters = 128
        self.num_res_blocks = 2
        self.ch_mult = [1, 1, 2, 2, 4]
        self.in_ch_mult = (1,) + tuple(self.ch_mult)
        self.embedding_dim = 32
        self.out_channels = 3
        self.in_channels = self.embedding_dim
        self.conv_downsample = False
        self.conv1 = nn.Conv2d(self.embedding_dim, self.filters * self.ch_mult[-1], kernel_size=3, stride=1, padding=1)

        blocks = []
        block_in_ch = self.filters * self.ch_mult[-1]
        block_out_ch = self.filters * self.ch_mult[-1]
        # res blocks at the bottleneck resolution
        for _ in range(self.num_res_blocks):
            blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))

        upsample_conv_layers = []
        for i in reversed(range(len(self.ch_mult))):
            block_out_ch = self.filters * self.ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
                block_in_ch = block_out_ch
            if i > 0:
                # 4x channel expansion feeds the depth-to-space upsample below
                upsample_conv_layers.append(nn.Conv2d(block_in_ch, block_out_ch * 4, kernel_size=3, stride=1, padding=1))

        # depth-to-space: (b, h, w, 4c) -> (b, 2h, 2w, c)
        self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2)
        self.norm1 = normalize(block_in_ch)
        self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.blocks = nn.ModuleList(blocks)
        self.up_convs = nn.ModuleList(upsample_conv_layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.blocks[0](x)
        x = self.blocks[1](x)
        for i in range(len(self.ch_mult)):
            for j in range(self.num_res_blocks):
                x = self.blocks[2 + i * self.num_res_blocks + j](x)
            if i < len(self.ch_mult) - 1:
                x = self.up_convs[i](x)
                # Rearrange operates on a channels-last layout
                x = x.permute(0, 2, 3, 1)
                x = self.upsample(x)
                x = x.permute(0, 3, 1, 2)
        x = self.norm1(x)
        x = swish(x)
        x = self.conv6(x)
        return x


class VQVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.quantizer = VectorQuantizer()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        quant, tokens, loss, perplexity = self.quantizer(x)
        x = self.decoder(quant)
        return x

    def tokenize(self, x):
        # accept arbitrary leading batch dims, e.g. (b, t, 3, h, w)
        batch_shape = x.shape[:-3]
        x = x.reshape(-1, *x.shape[-3:])
        x = self.encoder(x)
        quant, tokens, loss, perplexity = self.quantizer(x)
        return tokens.reshape(*batch_shape, *tokens.shape[1:])

    def decode(self, tokens):
        tokens = einops.rearrange(tokens, "b ... -> b (...)")
        b = tokens.shape[0]
        # infer the square token grid from the flattened length
        if tokens.shape[-1] == 256:
            hw = 16  # 256x256 input -> 16x16 tokens
        elif tokens.shape[-1] == 196:
            hw = 14  # 224x224 input -> 14x14 tokens
        else:
            raise ValueError("Invalid tokens shape")
        quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, self.quantizer.emb_dim))
        x = self.decoder(quant)
        return x
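# Hypothetical round trip (illustrative only; assumes 224x224 inputs so the
# token grid is 14x14):
#
#   model = VQVAE()
#   images = torch.randn(2, 3, 224, 224)
#   tokens = model.tokenize(images)  # (2, 14, 14), ids in [0, 8192)
#   recon = model.decode(tokens)     # (2, 3, 224, 224)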
class VAEDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.quantizer = VectorQuantizer()
        self.decoder = Decoder()

    def forward(self, x):
        # x: flattened token ids for a single 14x14 grid
        quant = self.quantizer.get_codebook_feat(x, (1, 14, 14, 32))
        x = self.decoder(quant)
        return x


def get_tokenizer():
    checkpoint_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth"
    )
    torch_state_dict = torch.load(checkpoint_path, map_location="cpu")
    net = VQVAE()
    net.load_state_dict(torch_state_dict)
    return net
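# Typical entry-point usage (sketch; assumes xh_ckpt.pth sits next to this
# file, as get_tokenizer() expects):
#
#   tokenizer = get_tokenizer().eval()
#   with torch.no_grad():
#       tokens = tokenizer.tokenize(images)  # images: (..., 3, 224, 224)
#       recon = tokenizer.decode(tokens)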