How to train a 16ch VAE decoder

Prerequisites

Read the MNIST example first and get familiar with PyTorch.

Training loop

Train your first module on color images.

for epoch in range(10):  # Loop over the dataset multiple times
    for i, data in enumerate(dataloader, 0):
        # Get the inputs; data is a list of [inputs, labels].
        inputs, labels = data

        # Zero the parameter gradients.
        optimizer.zero_grad()

        # Forward + backward + optimize.
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
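
The loop above assumes a network (net), a loss function (criterion), an optimizer and a dataloader already exist. A minimal sketch of that setup, using CIFAR-10 as the color-image dataset; the layer sizes are illustrative:

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(trainset, batch_size=4, shuffle=True)

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(16 * 16 * 16, 10),  # CIFAR-10 images are 3x32x32
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)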

Compare images

Compare images and measure how similar they are.

Stable Diffusion calculates the loss:

  • by MSE or
  • by MSE + LPIPS * 0.1

import numpy as np
from PIL import Image
import torch
from torchmetrics.functional.image import learned_perceptual_image_patch_similarity, spectral_angle_mapper

pil_image = Image.open('vehicle.png').convert('RGB')
# The original image ("target") and the predicted image ("output"),
# both NCHW float tensors scaled to [-1, 1].
# In this example, it's a perfect match.
target = torch.from_numpy(np.array(pil_image, dtype=np.float32) / 255.0)
target = target.permute(2, 0, 1).unsqueeze(0) * 2.0 - 1.0  # HWC in [0, 1] -> NCHW in [-1, 1]
output = target.clamp(-1.0, 1.0)

loss_sam = spectral_angle_mapper(output, target)
loss_lpips = learned_perceptual_image_patch_similarity(output, target)

loss_mse = torch.nn.functional.mse_loss(output, target, reduction='mean')
loss = loss_mse + loss_lpips * 0.1

Parquet

Load a downloaded dataset (for example, Parquet files from the Hugging Face Hub) or write your own dataset loader.

import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from torchvision.transforms.functional import pil_to_tensor

class ImageDataset(Dataset):
    def __init__(self, transform: Compose):
        super().__init__()
        self.transform = transform
        # Initialize image sources.
        ...

    def __len__(self):
        # Return the number of images; the DataLoader needs this.
        ...

    def __getitem__(self, idx):
        # Load your images from urls, folders or downloaded datasets.
        image = ...
        image = image.convert('RGB')
        # All images must have the same size.
        assert image.width == image.height
        # Note: non-tensor fields such as pil_image need a custom collate_fn
        # in the DataLoader; drop them if you use the default collation.
        return {'image': self.transform(image),
                'pil_image': image,
                'image_array': np.array(image),
                'image_tensor': pil_to_tensor(image)}
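
For example, a minimal concrete variant that reads square PNG files from a local folder (the class name and glob pattern are illustrative):

from pathlib import Path
from PIL import Image

class FolderImageDataset(Dataset):
    def __init__(self, root: str, transform: Compose):
        super().__init__()
        self.transform = transform
        self.paths = sorted(Path(root).glob('*.png'))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert('RGB')
        assert image.width == image.height
        return {'image': self.transform(image)}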

Dataloader

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Called as transform(pil_image); returns a CHW tensor normalized to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Option A.
# set_transform expects a function that maps a batch dict to a batch dict,
# so wrap the torchvision transform (assuming an 'image' column).
dataset = load_dataset(...)
dataset.set_transform(lambda batch: {'image': [transform(img) for img in batch['image']]})
dataloader = DataLoader(dataset, batch_size=2)
# Option B.
dataloader = DataLoader(ImageDataset(transform), batch_size=2)
# Option C.
dataloader = DataLoader(ImageFolder('path/to/image/folder', transform=transform), batch_size=2)
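
A quick sanity check, whichever option you pick: pull one batch and inspect it.

batch = next(iter(dataloader))
# ImageDataset and load_dataset yield a dict of columns; ImageFolder yields an [images, labels] pair.
images = batch['image'] if isinstance(batch, dict) else batch[0]
print(images.shape, images.min().item(), images.max().item())  # e.g. (2, 3, H, W), roughly in [-1, 1]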

Rescaling the latent space

In the latent space of a KL-regularized model, the latents come out at an arbitrary scale, and the diffusion model tends to allocate a lot of semantic detail early in the noising schedule: it captures too much information too quickly.

By rescaling the latent values by their component-wise standard deviation (exposed as the scaling_factor in the code below), the authors effectively reduce the signal-to-noise ratio.
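
As a rough sketch of where that constant comes from, assuming a batch of latents sampled from the encoder (the numbers below are placeholders, not values shipped with any particular model):

import torch

# Placeholder stand-in for latent_dist.sample() of a 16-channel VAE: (N, 16, H/8, W/8).
latents = torch.randn(8, 16, 64, 64) * 5.0
scaling_factor = 1.0 / latents.std()   # rescale to roughly unit standard deviation
scaled = latents * scaling_factor
print(scaling_factor, scaled.std())    # the real value lives in vae.config.scaling_factor

The helpers below apply the model's own scaling_factor (and, for SD3, shift_factor) when encoding and decoding: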

import diffusers
from PIL import Image
import torch
import torchvision.transforms

# ToTensor yields a CHW tensor in [0, 1]; encode_image rescales it to [-1, 1].
transform = torchvision.transforms.ToTensor()
# cuda_device and vae (the pretrained AutoencoderKL) are defined in the next section.

def encode_image(image):
    if isinstance(image, Image.Image):
        # Convert the PIL image to a CHW tensor in [0, 1].
        image = transform(image)
    image_tensor = image * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)  # the VAE expects a batch dimension
    image_tensor = image_tensor.to(cuda_device)
    latent = vae.encode(image_tensor)

    # SD1.5, SDXL
    # vae.config.scaling_factor * latent.latent_dist.sample(), image
    
    # SD3
    return vae.config.scaling_factor * (latent.latent_dist.sample() - vae.config.shift_factor), image 

def decode_image(latent: torch.Tensor):
    # SD1.5, SDXL
    # latent = 1 / vae.config.scaling_factor * latent

    # SD3
    latent = 1 / vae.config.scaling_factor * latent + vae.config.shift_factor 
    
    image_tensor = vae.decode(latent).sample  # DecoderOutput -> image tensor in roughly [-1, 1]
    
    return diffusers.utils.pt_to_pil(image_tensor)
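
A quick round trip through the two helpers above (the file name is illustrative):

pil_image = Image.open('vehicle.png').convert('RGB')
latent, _ = encode_image(pil_image)
reconstruction = decode_image(latent)[0]  # pt_to_pil returns a list of PIL images
reconstruction.save('roundtrip.png')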

Regression model

In the previous training exercises, you focused on classification tasks, where the goal was to predict a categorical label from a set of predefined categories. Unlike classification, regression models aim to establish a relationship between the inputs and outputs and make predictions on a continuous scale.
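
In loss terms, the difference looks roughly like this (a sketch with illustrative shapes, separate from the decoder training code below):

import torch
from torch import nn

features = torch.randn(4, 128)
# Classification: pick one of 10 discrete classes.
logits = nn.Linear(128, 10)(features)
cls_loss = nn.functional.cross_entropy(logits, torch.randint(0, 10, (4,)))
# Regression: predict continuous values, e.g. pixel intensities.
values = nn.Linear(128, 3)(features)
reg_loss = nn.functional.mse_loss(values, torch.rand(4, 3))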

Copy the sd3_impls Python file into your folder. You will work with two models: a pretrained Stable Diffusion VAE (used here for encoding) and your own VAE decoder (the one being trained).

from diffusers import AutoencoderKL
from safetensors.torch import load_model, save_model
from sd3_impls import VAEDecoder
import torch
from torch import nn

class SDVAE(nn.Module):
    def __init__(self):
        super().__init__()
        # Choose between torch.bfloat16 (newer, lower memory) and torch.float32 (older, full precision).
        self.decoder = VAEDecoder(dtype=torch.float32)

    def decode(self, latent):
        return self.decoder(latent)

cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae = AutoencoderKL.from_pretrained(
    'path/to/pretrained/model',
    subfolder='vae',
    revision=None,
    variant=None
).to(cuda_device)
vae.requires_grad_(False)  # Otherwise OOM.
# Option A.
# Train a new model.
model = SDVAE().to(cuda_device)
# Option B.
# Continue training from your previous checkpoint. 
model = SDVAE().to(cuda_device)
load_model(model, 'vae.safetensors')

dataloader = ...

optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
for epoch in range(10):
    for batch in dataloader:
        # Read the 'image' column from the batch;
        # preferably it is already a normalized tensor.
        norm_img, *_ = batch.values()
        optimizer.zero_grad()
        # Exercise to the reader:
        # Encode all images before the training loop, store them in a
        # latents[] list, and set model_input to this batch's latents.
        ...
        # Forward.
        output = model.decode(model_input)
        ...
        # Compare tensors, the output of the training model will be a bit blurry.
        # Exercise to the reader:
        # use the MSE + LPIPS * 0.1 formula to improve image fidelity.
        loss = torch.nn.functional.mse_loss(output[:, :3, :, :].float(),
                                            norm_img[:, :3, :, :].float(),
                                            reduction='mean')
        loss.backward()
        optimizer.step()
    save_model(model, './epoch_{:03d}.safetensors'.format(epoch))

save_model(model, 'vae.safetensors')
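
Once training is done, reload the checkpoint for a quick visual check (a sketch: encode_image comes from the "Rescaling the latent space" section, and the file name is illustrative):

from PIL import Image

model = SDVAE().to(cuda_device)
load_model(model, 'vae.safetensors')
model.eval()

with torch.no_grad():
    latent, _ = encode_image(Image.open('vehicle.png').convert('RGB'))
    decoded = model.decode(latent)  # image tensor in roughly [-1, 1]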

Note to model creators

Open-source training code and tagged datasets are essential for the community. Why not share your know-how, starting now?
