Wuerstchen
The Wuerstchen model drastically reduces computational costs by compressing the latent space by 42x without compromising image quality, which also accelerates inference. During training, Wuerstchen uses two models (VQGAN + autoencoder) to compress the latents, and then a third model (a text-conditioned latent diffusion model) is conditioned on this highly compressed space to generate an image.
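To make the 42x figure concrete, the sketch below compares per-side spatial compression factors. It assumes a 1024x1024 image maps to a 24x24 latent in Wuerstchen's most compressed stage, and uses an 8x-downsampling VAE (1024 → 128, as in Stable Diffusion) as a point of comparison. Both latent sizes are assumptions for illustration, not values read from the training script.

```python
def spatial_compression(image_side: int, latent_side: int) -> float:
    """Per-side spatial compression factor (image side / latent side)."""
    return image_side / latent_side


def latent_positions(latent_side: int) -> int:
    """Number of spatial positions the diffusion model has to denoise."""
    return latent_side * latent_side


# Assumed sizes: Wuerstchen's compressed latent vs. an 8x-downsampling VAE latent.
wuerstchen = spatial_compression(1024, 24)
vae_8x = spatial_compression(1024, 128)

print(f"Wuerstchen: {wuerstchen:.1f}x per side, {latent_positions(24)} positions")
print(f"8x VAE:     {vae_8x:.1f}x per side, {latent_positions(128)} positions")
```

Under these assumptions the diffusion model works on 576 latent positions instead of 16,384, which is where most of the compute savings come from.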
To fit the prior model into GPU memory and to speed up training, try enabling gradient_accumulation_steps, gradient_checkpointing, and mixed_precision.
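Gradient accumulation trades optimizer steps for memory: each forward/backward pass uses a small micro-batch, and the optimizer only updates after several of them, so the update matches a larger batch. The toy sketch below (plain Python, a one-parameter linear model on hypothetical data, not Wuerstchen itself) checks that averaging the gradients of four equally sized micro-batches gives the same gradient as one full batch.

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error for the model y_hat = w * x."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Average the gradients of equally sized micro-batches, as gradient accumulation does."""
    grads = []
    for start in range(0, len(xs), micro_batch_size):
        end = start + micro_batch_size
        grads.append(grad_mse(w, xs[start:end], ys[start:end]))
    return sum(grads) / len(grads)


# Hypothetical data and parameter value.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.0, 13.9, 16.2]
w = 0.5

full = grad_mse(w, xs, ys)
# 4 accumulation steps of micro-batches of size 2 -> same effective batch of 8.
accum = accumulated_grad(w, xs, ys, micro_batch_size=2)
print(full, accum)
```

Only one micro-batch of activations has to live in memory at a time, which is why accumulation helps the prior fit on a smaller GPU.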
This guide explores the train_text_to_image_prior.py script to help you become more familiar with it and learn how you can adapt it for your own use case.
Before running the script, make sure you install the library from source:
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
Then navigate to the example folder containing the training script and install the required dependencies for the script you’re using:
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed precision. It’ll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate Quick tour to learn more.
Initialize an 🤗 Accelerate environment:
accelerate config
To set up a default 🤗 Accelerate environment without choosing any configurations:
accelerate config default
Or if your environment doesn’t support an interactive shell, like a notebook, you can use:
from accelerate.utils import write_basic_config

write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the Create a dataset for training guide to learn how to create a dataset that works with the training script.
The following sections highlight parts of the training script that are important for understanding how to modify it, but they don’t cover every aspect of the script in detail. If you’re interested in learning more, feel free to read through the script and let us know if you have any questions or concerns.
Script parameters
The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the parse_args() function. It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you’d like.
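The pattern behind parse_args() is the standard argparse one: every flag has a default, and anything passed on the command line overrides it. The sketch below uses a hypothetical subset of flags named after the ones in this guide; the defaults shown are illustrative and are not the script’s actual defaults.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical subset of arguments; names mirror flags used in this guide,
    # but the defaults here are illustrative, not the script's actual defaults.
    parser = argparse.ArgumentParser(description="Sketch of a training-script argument parser")
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
    return parser.parse_args(argv)


defaults = parse_args([])  # nothing passed: every value falls back to its default
custom = parse_args(["--mixed_precision=fp16", "--train_batch_size", "4"])
print(defaults.mixed_precision, custom.mixed_precision, custom.train_batch_size)
```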
For example, to speed up training with mixed precision using the fp16 format, add the --mixed_precision parameter to the training command:
accelerate launch train_text_to_image_prior.py \
--mixed_precision="fp16"
Most of the parameters are identical to the parameters in the Text-to-image training guide, so let’s dive right into the Wuerstchen training script!
Training script
The training script is also similar to the Text-to-image training guide, but it’s been modified to support Wuerstchen. This guide focuses on the code that is unique to the Wuerstchen training script.
The main() function starts by initializing the image encoder, an EfficientNet, in addition to the usual scheduler and tokenizer.
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
    pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
    state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
    image_encoder = EfficientNetEncoder()
    image_encoder.load_state_dict(state_dict["effnet_state_dict"])
    image_encoder.eval()
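The ContextManagers(...) wrapper enters every context manager in a list. Presumably deepspeed_zero_init_disabled_context_manager() returns a context that turns off DeepSpeed ZeRO-3 parameter partitioning while the frozen encoder is built, and an empty list when DeepSpeed isn’t in use. The sketch below reimplements that wrapper idea with the standard library’s contextlib.ExitStack; EnterAll and traced are hypothetical names, and this is an assumption about the wrapper’s behavior, not the library source.

```python
from contextlib import ExitStack, contextmanager


class EnterAll:
    """Hypothetical stand-in for a ContextManagers-style wrapper: enters a list of
    context managers in order and exits them in reverse order."""

    def __init__(self, context_managers):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for cm in self.context_managers:
            self.stack.enter_context(cm)
        return self

    def __exit__(self, *exc_info):
        return self.stack.__exit__(*exc_info)


events = []


@contextmanager
def traced(name):
    # Records when each context is entered and exited.
    events.append(f"enter {name}")
    yield
    events.append(f"exit {name}")


# An empty list (e.g. no DeepSpeed) makes the wrapper a no-op.
with EnterAll([]):
    events.append("build encoder")

# With contexts, they wrap the body and unwind in reverse order.
with EnterAll([traced("a"), traced("b")]):
    events.append("build encoder")

print(events)
```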
You’ll also load the WuerstchenPrior model for optimization.
prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
optimizer = optimizer_cls(
    prior.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
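The adam_beta1, adam_beta2, adam_weight_decay, and adam_epsilon arguments feed an Adam-style optimizer. Assuming optimizer_cls resolves to AdamW (the adam_* argument names suggest this, but check the script), the scalar sketch below shows where each hyperparameter enters a single update. It follows the usual decoupled-weight-decay formulation and is not the PyTorch implementation itself.

```python
import math


def adamw_step(param, grad, m, v, step, lr, beta1, beta2, eps, weight_decay):
    """One AdamW update for a single scalar parameter; returns (param, m, v).

    Weight decay is decoupled: it shrinks the parameter directly instead of
    being added to the gradient.
    """
    param = param * (1 - lr * weight_decay)    # decoupled weight decay
    m = beta1 * m + (1 - beta1) * grad         # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad * grad  # second-moment estimate
    m_hat = m / (1 - beta1 ** step)            # bias correction for early steps
    v_hat = v / (1 - beta2 ** step)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = adamw_step(1.0, grad=1.0, m=0.0, v=0.0, step=1,
                     lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0)
print(round(p, 6))
```

On the first step the bias correction makes the update roughly lr times the sign of the gradient, regardless of the gradient’s scale.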
Next, you’ll apply some transforms to the images and tokenize the captions:
def preprocess_train(examples):
    images = [image.convert("RGB") for image in examples[image_column]]
    examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
    examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
    return examples
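tokenize_captions returns both token ids and a mask. Assuming it pads every caption to a fixed length, as CLIP-style tokenizers typically do, the relationship between the two outputs can be sketched in plain Python. pad_and_mask and the token ids are hypothetical, and a real tokenizer also inserts special tokens, which this sketch skips.

```python
def pad_and_mask(token_ids, max_length, pad_id=0):
    """Truncate or pad a list of token ids to max_length and build an attention mask.

    The mask is True for real tokens and False for padding, so the text encoder
    can ignore padded positions.
    """
    ids = token_ids[:max_length]
    mask = [True] * len(ids)
    padding = max_length - len(ids)
    return ids + [pad_id] * padding, mask + [False] * padding


# Hypothetical token ids for two captions of different lengths.
captions = [[11, 12, 13], [21, 22, 23, 24, 25, 26, 27, 28]]
batch = [pad_and_mask(ids, max_length=6) for ids in captions]
text_input_ids = [ids for ids, _ in batch]
text_mask = [mask for _, mask in batch]
print(text_input_ids)
print(text_mask)
```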
Finally, the training loop handles compressing the images to latent space with the EfficientNetEncoder, adding noise to the latents, and predicting the noise residual with the WuerstchenPrior model.
pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
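The noise-and-predict step around that line can be sketched in plain Python. This assumes a cosine schedule over continuous timesteps in [0, 1] (offset s = 0.008, clamped away from 0 and 1, a common choice; check the scheduler the script actually uses for the exact form) and a mean squared error between predicted and true noise. Lists stand in for tensors, and the prior call is replaced by a placeholder prediction.

```python
import math
import random


def alpha_cumprod(t, s=0.008):
    """Cosine noise schedule for a continuous timestep t in [0, 1], clamped to [0.0001, 0.9999]."""
    f = math.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    f0 = math.cos(s / (1 + s) * math.pi / 2) ** 2
    return min(max(f / f0, 0.0001), 0.9999)


def add_noise(latents, noise, t):
    """Mix clean latents with Gaussian noise according to the schedule at time t."""
    a = alpha_cumprod(t)
    return [math.sqrt(a) * x + math.sqrt(1 - a) * n for x, n in zip(latents, noise)]


def mse(pred, target):
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


random.seed(0)
latents = [random.gauss(0, 1) for _ in range(16)]  # stand-in for EfficientNet latents
noise = [random.gauss(0, 1) for _ in range(16)]
t = random.random()                                # continuous timestep in [0, 1)
noisy_latents = add_noise(latents, noise, t)

# A real step would call prior(noisy_latents, timesteps, prompt_embeds);
# here a "perfect" prediction is the noise itself, so the loss is zero.
pred_noise = noise
loss = mse(pred_noise, noise)
print(t, loss)
```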
If you want to learn more about how the training loop works, check out the Understanding pipelines, models and schedulers tutorial which breaks down the basic pattern of the denoising process.
Launch the script
Once you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 🚀
Set the DATASET_NAME environment variable to the dataset name from the Hub. This guide uses the Pokémon BLIP captions dataset, but you can create and train on your own datasets as well (see the Create a dataset for training guide).
To monitor training progress with Weights & Biases, add the --report_to=wandb parameter to the training command. You’ll also need to add --validation_prompts to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image_prior.py \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME \
--resolution=768 \
--train_batch_size=4 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--dataloader_num_workers=4 \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--checkpoints_total_limit=3 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--validation_prompts="A robot pokemon, 4k photo" \
--report_to="wandb" \
--push_to_hub \
--output_dir="wuerstchen-prior-pokemon-model"
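To get a feel for how much data the command above pushes through the prior, the arithmetic below multiplies out the batch settings. It assumes --max_train_steps counts optimizer updates (that is, steps after gradient accumulation), a single GPU, and a hypothetical dataset of 1,000 images; substitute your own numbers.

```python
def training_budget(max_train_steps, train_batch_size, gradient_accumulation_steps,
                    num_processes=1, dataset_size=None):
    """Rough count of images processed, assuming max_train_steps counts optimizer updates."""
    samples_per_update = train_batch_size * gradient_accumulation_steps * num_processes
    samples_seen = max_train_steps * samples_per_update
    epochs = samples_seen / dataset_size if dataset_size else None
    return samples_per_update, samples_seen, epochs


# Values from the command above on a single GPU; dataset_size is a placeholder.
per_update, seen, epochs = training_budget(15000, 4, 4, num_processes=1, dataset_size=1000)
print(per_update, seen, epochs)
```

Doubling the number of GPUs doubles samples_per_update, so you may want to lower --max_train_steps or --gradient_accumulation_steps to keep the total budget comparable.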
Once training is complete, you can use your newly trained model for inference!
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torch_dtype=torch.float16).to("cuda")
caption = "A cute bird pokemon holding a shield"
images = pipeline(
caption,
width=1024,
height=1536,
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
prior_guidance_scale=4.0,
num_images_per_prompt=2,
).images
Next steps
Congratulations on training a Wuerstchen model! To learn more about how to use your new model, the following may be helpful:
- Take a look at the Wuerstchen API documentation to learn more about how to use the pipeline for text-to-image generation and its limitations.