Model is broken when fine-tuning, and the CompVis and Diffusers models are identical
There are ostensibly 2 versions of this VAE:
https://huggingface.co/stabilityai/sd-vae-ft-mse
https://huggingface.co/stabilityai/sd-vae-ft-mse-original
The tables list the same file: https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt
The vae-ft-mse-840000-ema-pruned.ckpt files under "Files and versions" in both repos also have the same sha256sum: c6a580b13a5bc05a5e16e4dbb80608ff2ec251a162311590c1f34c013d7f3dab
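For anyone who wants to reproduce the checksum comparison locally, here is a minimal sketch using only the Python standard library. The file paths in the usage comment are hypothetical; they assume you have downloaded the checkpoint from each repo into separate folders.

```python
import hashlib
from pathlib import Path


def sha256sum(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in chunks to keep memory low."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b"" (EOF).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def same_file(path_a, path_b):
    """True if both files have identical SHA-256 digests."""
    return sha256sum(path_a) == sha256sum(path_b)


# Hypothetical usage (paths are assumptions, not part of either repo):
# same_file("sd-vae-ft-mse/vae-ft-mse-840000-ema-pruned.ckpt",
#           "sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt")
```

Matching digests mean the two downloads are byte-for-byte the same file, which is what the hash above suggests.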
Unfortunately, this model is not compatible with the CompVis/latent-diffusion repo. When trying to load it in CompVis, I get the following error:
RuntimeError: Error(s) in loading state_dict for AutoencoderKL:
Missing key(s) in state_dict: "loss.logvar", "loss.perceptual_loss.scaling_layer.shift", "loss.perceptual_loss.scaling_layer.scale", "loss.perceptual_loss.net.slice1.0.weight", "loss.perceptual_loss.net.slice1.0.bias", "loss.perceptual_loss.net.slice1.2.weight", "loss.perceptual_loss.net.slice1.2.bias", "loss.perceptual_loss.net.slice2.5.weight", "loss.perceptual_loss.net.slice2.5.bias", "loss.perceptual_loss.net.slice2.7.weight", "loss.perceptual_loss.net.slice2.7.bias", "loss.perceptual_loss.net.slice3.10.weight", "loss.perceptual_loss.net.slice3.10.bias", "loss.perceptual_loss.net.slice3.12.weight", "loss.perceptual_loss.net.slice3.12.bias", "loss.perceptual_loss.net.slice3.14.weight", "loss.perceptual_loss.net.slice3.14.bias", "loss.perceptual_loss.net.slice4.17.weight", "loss.perceptual_loss.net.slice4.17.bias", "loss.perceptual_loss.net.slice4.19.weight", "loss.perceptual_loss.net.slice4.19.bias", "loss.perceptual_loss.net.slice4.21.weight", "loss.perceptual_loss.net.slice4.21.bias", "loss.perceptual_loss.net.slice5.24.weight", "loss.perceptual_loss.net.slice5.24.bias", "loss.perceptual_loss.net.slice5.26.weight", "loss.perceptual_loss.net.slice5.26.bias", "loss.perceptual_loss.net.slice5.28.weight", "loss.perceptual_loss.net.slice5.28.bias", "loss.perceptual_loss.lin0.model.1.weight", "loss.perceptual_loss.lin1.model.1.weight", "loss.perceptual_loss.lin2.model.1.weight", "loss.perceptual_loss.lin3.model.1.weight", "loss.perceptual_loss.lin4.model.1.weight", "loss.discriminator.main.0.weight", "loss.discriminator.main.0.bias", "loss.discriminator.main.2.weight", "loss.discriminator.main.3.weight", "loss.discriminator.main.3.bias", "loss.discriminator.main.3.running_mean", "loss.discriminator.main.3.running_var", "loss.discriminator.main.5.weight", "loss.discriminator.main.6.weight", "loss.discriminator.main.6.bias", "loss.discriminator.main.6.running_mean", "loss.discriminator.main.6.running_var", "loss.discriminator.main.8.weight", 
"loss.discriminator.main.9.weight", "loss.discriminator.main.9.bias", "loss.discriminator.main.9.running_mean", "loss.discriminator.main.9.running_var", "loss.discriminator.main.11.weight", "loss.discriminator.main.11.bias".
Unexpected key(s) in state_dict: "model_ema.decay", "model_ema.num_updates".
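The error shows two kinds of mismatch: the checkpoint lacks every `loss.*` key (the perceptual loss and discriminator) and carries two extra `model_ema.*` keys. A possible workaround, which I have not verified against the CompVis code, is to drop the EMA bookkeeping keys and load non-strictly. The helper below is a hypothetical sketch that works on plain dicts, so you can inspect what would be skipped before touching the model; the function name and the prefix defaults are my own assumptions.

```python
def split_state_dict(checkpoint_sd, model_keys,
                     drop_prefixes=("model_ema.",),
                     optional_prefixes=("loss.",)):
    """Prepare a checkpoint for a non-strict load.

    Returns (filtered, tolerated, blocking):
      filtered  - checkpoint entries with the drop_prefixes keys removed
      tolerated - model keys absent from the checkpoint that start with
                  optional_prefixes (assumed safe to leave at their init values)
      blocking  - any other model keys absent from the checkpoint
    """
    drop = tuple(drop_prefixes)
    optional = tuple(optional_prefixes)
    filtered = {k: v for k, v in checkpoint_sd.items() if not k.startswith(drop)}
    missing = [k for k in model_keys if k not in filtered]
    tolerated = [k for k in missing if k.startswith(optional)]
    blocking = [k for k in missing if not k.startswith(optional)]
    return filtered, tolerated, blocking
```

With a real model you would pass `model.state_dict().keys()` as `model_keys` and, if `blocking` is empty, call `model.load_state_dict(filtered, strict=False)`. Note the assumptions: the weights may sit under a top-level `"state_dict"` entry in the .ckpt, and the `loss.*` parameters stay at whatever the constructor gave them, so the discriminator would effectively start from scratch.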
Gentle ping @multimodalart
Is there a pre-made script to train a diffusers VQVAE? Since it quantizes the encoder's output through a codebook, backpropagation is not trivial. This implementation is quite complete (link: https://github.com/ritheshkumar95/pytorch-vqvae); however, I have no clear idea how to wrap its VQVAE into a VQModel so I can later use diffusers pipelines. Would you have any idea, @patrickvonplaten? Thanks in advance.
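On the backpropagation point: the usual way around the non-differentiable codebook lookup is the straight-through estimator from the original VQ-VAE formulation. The equations below are that standard formulation, written out as a reminder; I am not claiming that diffusers' VQModel or the linked repo implement it term for term. Here $z_e(x)$ is the encoder output, $e_j$ are codebook vectors, $D$ is the decoder, $\mathrm{sg}[\cdot]$ is stop-gradient, and $\beta$ is the commitment weight.

```latex
% Nearest-neighbour quantization (not differentiable in z_e):
k = \arg\min_j \lVert z_e(x) - e_j \rVert_2 , \qquad z_q(x) = e_k

% Straight-through estimator: forward value is e_k,
% but the gradient flows to z_e as if quantization were the identity.
z_q(x) = z_e(x) + \mathrm{sg}\!\left[ e_k - z_e(x) \right]

% Training objective: reconstruction + codebook + commitment terms.
\mathcal{L} = \lVert x - D(z_q(x)) \rVert_2^2
            + \lVert \mathrm{sg}[z_e(x)] - e_k \rVert_2^2
            + \beta \, \lVert z_e(x) - \mathrm{sg}[e_k] \rVert_2^2
```

The reconstruction term trains the encoder and decoder through the straight-through path, the second term moves the selected codebook vector toward the encoder output, and the third keeps the encoder output close to its chosen code.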
I would also like to know. There’s a repository from someone who trains an anime autoencoder here:
https://github.com/cccntu/fine-tune-models
The problem is that it’s not really easy to use, and there are a lot of JAX dependencies that I’m not familiar with.
I guess that people are training the autoencoder using the latent diffusion repo by CompVis:
https://github.com/CompVis/latent-diffusion
to train a KL-8 encoder that’s swappable with the main encoder of Stable Diffusion (I guess one can just rename the ckpt to bin?).
To train a VQ model, the “taming transformers” repo should work, but again it’s not clear or well documented whether that works for SD.