from typing import Optional

import torch
from torch.utils.checkpoint import checkpoint

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL, Decoder, Encoder


class MyEncoder(Encoder):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=...,
        block_out_channels=...,
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
        mid_block_add_attention=True,
    ):
        super().__init__(
            in_channels, out_channels, down_block_types, block_out_channels,
            layers_per_block, norm_num_groups, act_fn, double_z, mid_block_add_attention,
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        r"""The forward method of the `Encoder` class."""
        sample = self.conv_in(sample)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Recompute the down and mid block activations during the backward
            # pass to trade compute for memory.
            for down_block in self.down_blocks:
                sample = checkpoint(self.ckpt_wrapper(down_block), sample)
            sample = checkpoint(self.ckpt_wrapper(self.mid_block), sample)
        else:
            for down_block in self.down_blocks:
                sample = down_block(sample)
            sample = self.mid_block(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample

    def ckpt_wrapper(self, module):
        # Wrap the module in a plain function so it can be passed to
        # `torch.utils.checkpoint.checkpoint`.
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward

class MyDecoder(Decoder):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=...,
        block_out_channels=...,
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",
        mid_block_add_attention=True,
    ):
        super().__init__(
            in_channels, out_channels, up_block_types, block_out_channels,
            layers_per_block, norm_num_groups, act_fn, norm_type, mid_block_add_attention,
        )

    def forward(
        self,
        sample: torch.Tensor,
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""The forward method of the `Decoder` class."""
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Recompute the mid and up block activations during the backward
            # pass to trade compute for memory.
            sample = checkpoint(self.ckpt_wrapper(self.mid_block), sample, latent_embeds)
            sample = sample.to(upscale_dtype)
            for up_block in self.up_blocks:
                sample = checkpoint(self.ckpt_wrapper(up_block), sample, latent_embeds)
        else:
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample

    def ckpt_wrapper(self, module):
        # Wrap the module in a plain function so it can be passed to
        # `torch.utils.checkpoint.checkpoint`.
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward

class MyAutoencoderKL(AutoencoderKL):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=...,
        up_block_types=...,
        block_out_channels=...,
        layers_per_block=1,
        act_fn="silu",
        latent_channels=4,
        norm_num_groups=32,
        sample_size=32,
        scaling_factor=0.18215,
        shift_factor=None,
        latents_mean=None,
        latents_std=None,
        force_upcast=True,
        use_quant_conv=True,
        use_post_quant_conv=True,
        mid_block_add_attention=True,
        bn_momentum=0.1,
    ):
        super().__init__(
            in_channels, out_channels, down_block_types, up_block_types, block_out_channels,
            layers_per_block, act_fn, latent_channels, norm_num_groups, sample_size,
            scaling_factor, shift_factor, latents_mean, latents_std, force_upcast,
            use_quant_conv, use_post_quant_conv, mid_block_add_attention,
        )

        # Replace the stock encoder/decoder with the checkpointing-aware variants above.
        self.encoder = MyEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.decoder = MyDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )

        # Affine-free BatchNorm over the latent channels: it only tracks running
        # mean/variance statistics of the sampled latents.
        self.bn = torch.nn.BatchNorm2d(
            latent_channels, eps=1e-4, momentum=bn_momentum, affine=False, track_running_stats=True
        )
        self.bn.reset_running_stats()
        self.init_bn()

    def init_bn(self):
        # Initialize the running statistics so the initial per-channel std equals
        # 1 / scaling_factor (running_var = 1 / scaling_factor**2).
        self.bn.running_mean = torch.zeros_like(self.bn.running_mean)
        self.bn.running_var = torch.ones_like(self.bn.running_var) / self.config.scaling_factor ** 2

    @property
    def mean(self):
        # Running per-channel mean, broadcastable to (N, C, H, W) latents.
        mean = self.bn.running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        return mean

    @property
    def std(self):
        # Running per-channel std, broadcastable to (N, C, H, W) latents.
        std = self.bn.running_var.sqrt().unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        return std

    def forward(self, x: torch.Tensor, use_checkpoint=False):
        self.encoder.gradient_checkpointing = use_checkpoint
        self.decoder.gradient_checkpointing = use_checkpoint
        posterior = self.encode(x).latent_dist
        z = posterior.sample()
        # Pass the sampled latents through BatchNorm to update the running statistics
        # and obtain a normalized latent; the reconstruction is decoded from the raw
        # sample `z`.
        latent = self.bn(z)
        recon = self.decode(z).sample
        return posterior, latent, recon
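

# Minimal usage sketch, illustrative only: the block types, channel widths, batch
# size, and image size below are assumed values chosen for a smoke test, not
# defaults defined in this module (its `...` defaults must be overridden anyway).
if __name__ == "__main__":
    vae = MyAutoencoderKL(
        down_block_types=("DownEncoderBlock2D",) * 2,
        up_block_types=("UpDecoderBlock2D",) * 2,
        block_out_channels=(64, 128),
        latent_channels=4,
    )

    x = torch.randn(2, 3, 64, 64)
    posterior, latent, recon = vae(x, use_checkpoint=False)
    print(posterior.mean.shape, latent.shape, recon.shape)

    # The tracked running statistics can also be used to (de)normalize latents
    # explicitly via the `mean` and `std` properties.
    z = posterior.sample()
    z_norm = (z - vae.mean) / vae.std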