| """Simple implementation of AutoEncoderVQ.""" |
|
|
import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin

from diffnext.models.autoencoders.autoencoder_kl import Attention, Decoder, Encoder
from diffnext.models.autoencoders.modeling_utils import DecoderOutput, IdentityDistribution
from diffnext.models.autoencoders import quantizers


class AutoencoderVQ(ModelMixin, ConfigMixin):
    """AutoEncoder VQ."""

    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",) * 4,
        up_block_types=("UpDecoderBlock2D",) * 4,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        act_fn="silu",
        latent_channels=16,
        norm_num_groups=32,
        sample_size=1024,
        num_vq_embeddings=16384,
        vq_embed_dim=8,
        attn_down_block=False,
        attn_up_block=False,
        force_upcast=False,
        temporal_stride=1,
        spatial_stride=16,
        decoder_dtype=None,
        _quantizer_name="VQuantizer",
    ):
        super().__init__()
        channels, layers = block_out_channels, layers_per_block
        self.encoder = Encoder(in_channels, latent_channels, channels, layers)
        self.decoder = Decoder(latent_channels, out_channels, channels, layers)
        # 1x1 convolutions projecting between the autoencoder latent space and the codebook space.
        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
        # Optionally augment the deepest encoder/decoder blocks with self-attention.
        if attn_down_block:
            attentions = [Attention(block_out_channels[-1]) for _ in range(layers_per_block)]
            self.encoder.down_blocks[-1].attentions += attentions
        if attn_up_block:
            attentions = [Attention(block_out_channels[-1]) for _ in range(layers_per_block + 1)]
            self.decoder.up_blocks[0].attentions += attentions
        # Resolve the quantizer class by name, e.g. ``quantizers.VQuantizer``.
        self.quantizer = getattr(quantizers, _quantizer_name)(num_vq_embeddings, vq_embed_dim)
        self.latent_dist = IdentityDistribution

    def to(self, *args, **kwargs):
        """Convert to given device and dtype."""
        super().to(*args, **kwargs)
        # Keep the decoder in its own dtype if one is pinned in the config.
        if self.config.decoder_dtype:
            self.decoder.to(dtype=getattr(torch, self.config.decoder_dtype))
        return self

    def scale_(self, x) -> torch.Tensor:
        """Scale the input latents (identity for VQ latents)."""
        return x

    def unscale_(self, x) -> torch.Tensor:
        """Unscale the input latents (identity for VQ latents)."""
        return x

    def encode(self, x) -> AutoencoderKLOutput:
        """Encode the input samples."""
        z = self.encoder(self.forward(x))
        z = self.quant_conv(z)
        # Wrap the quantizer output in an identity "distribution" for API parity with AutoencoderKL.
        posterior = self.latent_dist(self.quantizer.quantize(z))
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, ids) -> DecoderOutput:
        """Decode the input indices."""
        z = self.quantizer.dequantize(ids)
        # Fold the temporal axis of 5D video latents into the batch:
        # (N, C, T, H, W) -> (N * T, C, H, W); a singleton T axis is squeezed away.
        t = z.size(2) if z.dim() == 5 else 1
        z = z.transpose(1, 2).flatten(0, 1) if t > 1 else z
        z = z.squeeze_(2) if z.dim() == 5 else z
        x = self.post_quant_conv(self.forward(z))
        # Cast to the decoder's parameter dtype, which may differ when ``decoder_dtype`` is set.
        x = self.decoder(x.to(self.decoder.conv_in.weight))
        # Restore the temporal axis: (N * T, C, H, W) -> (N, C, T, H, W).
        x = x.view(-1, t, *x.shape[1:]).transpose(1, 2) if t > 1 else x
        return DecoderOutput(sample=x)

    def forward(self, x):
        """Identity pass-through; use :meth:`encode` and :meth:`decode` instead."""
        return x
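

# A minimal round-trip sketch, not part of the model: it assumes the default config
# and that ``IdentityDistribution.sample`` returns the quantizer output unchanged,
# so what ``encode`` emits is what ``decode`` accepts. The 256x256 input size is
# illustrative; any size divisible by ``spatial_stride`` should work.
if __name__ == "__main__":
    vae = AutoencoderVQ().eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        ids = vae.encode(x).latent_dist.sample()  # quantized codes from the VQuantizer
        recon = vae.decode(ids).sample  # reconstruction at the input resolution
    print(tuple(recon.shape))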
|
|