# ColorizeDiffusion / ldm/models/autoencoder.py
# Last commit by tellurion (47ab351):
#   "Add refnet/models and ldm/models source files previously excluded by .gitignore"
import torch
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
class AutoencoderKL(torch.nn.Module):
    """KL-regularized autoencoder built from the ldm ``Encoder``/``Decoder``.

    ``encode`` runs the encoder followed by a 1x1 ``quant_conv`` and wraps the
    result in a ``DiagonalGaussianDistribution``. ``decode`` runs a 1x1
    ``post_quant_conv`` followed by the decoder.

    Args:
        ddconfig: Keyword arguments forwarded unchanged to both ``Encoder`` and
            ``Decoder``. Must contain ``"z_channels"`` and a truthy
            ``"double_z"``. ``quant_conv`` expects the encoder output to have
            ``2 * z_channels`` channels.
        embed_dim: Channel count of the latent accepted by ``decode``.
            ``quant_conv`` emits ``2 * embed_dim`` channels.

    Raises:
        ValueError: If ``ddconfig["double_z"]`` is falsy.
        KeyError: If ``ddconfig`` has no ``"double_z"`` entry.
    """

    def __init__(
        self,
        ddconfig,
        embed_dim
    ):
        super().__init__()
        # Validate before building the (potentially large) encoder/decoder.
        # The original ``assert`` was stripped under ``python -O``.
        if not ddconfig["double_z"]:
            raise ValueError(
                "AutoencoderKL requires ddconfig['double_z'] to be truthy; "
                f"got {ddconfig['double_z']!r}"
            )
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        z_channels = ddconfig["z_channels"]
        # 1x1 convs translating between the encoder/decoder channel width
        # (z_channels) and the latent width (embed_dim).
        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` posterior.

        The distribution is built from the ``2 * embed_dim``-channel output of
        ``quant_conv``. It presumably splits into mean and log-variance halves;
        see ``DiagonalGaussianDistribution`` to confirm.
        """
        h = self.encoder(x)
        moments = self.quant_conv(h)
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Decode a latent ``z`` (``embed_dim`` channels) with the decoder."""
        z = self.post_quant_conv(z)
        return self.decoder(z)

    def get_last_layer(self):
        """Return the weight tensor of the decoder's final conv (``conv_out``)."""
        return self.decoder.conv_out.weight

    @property
    def dtype(self):
        """dtype of the decoder's ``conv_out`` weight.

        NOTE(review): this is used as the module's dtype on the assumption that
        all parameters share one dtype. Confirm for mixed-precision setups.
        """
        return self.get_last_layer().dtype