Diffusers documentation

ControlNet

You are viewing main version, which requires installation from source. If you'd like regular pip install, checkout the latest stable version (v0.35.1).
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

ControlNet

ControlNet 是一种基于预训练模型的适配器架构。它通过额外输入的条件图像(如边缘检测图、深度图、人体姿态图等),实现对生成图像的精细化控制。

在显存有限的GPU上训练时,建议启用训练命令中的 gradient_checkpointing(梯度检查点)、gradient_accumulation_steps(梯度累积步数)和 mixed_precision(混合精度)参数。还可使用 xFormers 的内存高效注意力机制进一步降低显存占用。虽然JAX/Flax训练支持在TPU和GPU上高效运行,但不支持梯度检查点和xFormers。若需通过Flax加速训练,建议使用显存大于30GB的GPU。

本指南将解析 train_controlnet.py 训练脚本,帮助您理解其逻辑并适配自定义需求。

运行脚本前,请确保从源码安装库:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

然后进入包含训练脚本的示例目录,安装所需依赖:

PyTorch
Flax
cd examples/controlnet
pip install -r requirements.txt

🤗 Accelerate 是一个支持多GPU/TPU训练和混合精度的库,它能根据硬件环境自动配置训练方案。参阅 🤗 Accelerate 快速入门 了解更多。

初始化🤗 Accelerate环境:

accelerate config

若要创建默认配置(不进行交互式选择):

accelerate config default

若环境不支持交互式shell(如notebook),可使用:

from accelerate.utils import write_basic_config

write_basic_config()

最后,如需训练自定义数据集,请参阅 创建训练数据集 指南了解数据准备方法。

下文重点解析脚本中的关键模块,但不会覆盖所有实现细节。如需深入了解,建议直接阅读 脚本源码,如有疑问欢迎反馈。

脚本参数

训练脚本提供了丰富的可配置参数,所有参数及其说明详见 parse_args() 函数。虽然该函数已为每个参数提供默认值(如训练批大小、学习率等),但您可以通过命令行参数覆盖这些默认值。

例如,使用fp16混合精度加速训练, 可使用--mixed_precision参数

accelerate launch train_controlnet.py \
  --mixed_precision="fp16"

基础参数说明可参考 文生图 训练指南,此处重点介绍ControlNet相关参数:

  • --max_train_samples: 训练样本数量,减少该值可加快训练,但对超大数据集需配合 --streaming 参数使用
  • --gradient_accumulation_steps: 梯度累积步数,通过分步计算实现显存受限情况下的更大批次训练

Min-SNR加权策略

Min-SNR 加权策略通过重新平衡损失函数加速模型收敛。虽然训练脚本支持预测 epsilon(噪声)或 v_prediction,但Min-SNR对两种预测类型均兼容。该策略仅适用于PyTorch版本,Flax训练脚本暂不支持。

推荐值设为5.0:

accelerate launch train_controlnet.py \
  --snr_gamma=5.0

训练脚本

与参数说明类似,训练流程的通用解析可参考 文生图 指南。此处重点分析ControlNet特有的实现。

脚本中的 make_train_dataset 函数负责数据预处理,除常规的文本标注分词和图像变换外,还包含条件图像的特效处理:

在TPU上流式加载数据集时,🤗 Datasets库可能成为性能瓶颈(因其未针对图像数据优化)。建议考虑 WebDatasetTorchDataTensorFlow Datasets 等高效数据格式。

conditioning_image_transforms = transforms.Compose(
    [
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution),
        transforms.ToTensor(),
    ]
)

main() 函数中,代码会加载分词器、文本编码器、调度器和模型。此处也是ControlNet模型的加载点(支持从现有权重加载或从UNet随机初始化):

if args.controlnet_model_name_or_path:
    logger.info("Loading existing controlnet weights")
    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
else:
    logger.info("Initializing controlnet weights from unet")
    controlnet = ControlNetModel.from_unet(unet)

优化器 专门针对ControlNet参数进行更新:

params_to_optimize = controlnet.parameters()
optimizer = optimizer_class(
    params_to_optimize,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

训练循环 中,条件文本嵌入和图像被输入到ControlNet的下采样和中层模块:

encoder_hidden_states = text_encoder(batch["input_ids"])[0]
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

down_block_res_samples, mid_block_res_sample = controlnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_image,
    return_dict=False,
)

若想深入理解训练循环机制,可参阅 理解管道、模型与调度器 教程,该教程详细解析了去噪过程的基本原理。

启动训练

现在可以启动训练脚本了!🚀

本指南使用 fusing/fill50k 数据集,当然您也可以按照 创建训练数据集 指南准备自定义数据。

设置环境变量 MODEL_NAME 为Hub模型ID或本地路径,OUTPUT_DIR 为模型保存路径。

下载训练用的条件图像:

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png

根据GPU型号,可能需要启用特定优化。默认配置需要约38GB显存。若使用多GPU训练,请在 accelerate launch 命令中添加 --multi_gpu 参数。

16GB
12GB
8GB

16GB显卡可使用bitsandbytes 8-bit优化器和梯度检查点:

pip install bitsandbytes

训练命令添加以下参数:

accelerate launch train_controlnet.py \
  --gradient_checkpointing \
  --use_8bit_adam \
PyTorch
Flax
export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/save/model"

accelerate launch train_controlnet.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --resolution=512 \
 --learning_rate=1e-5 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4 \
 --push_to_hub

训练完成后即可进行推理:

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
import torch

controlnet = ControlNetModel.from_pretrained("path/to/controlnet", torch_dtype=torch.float16)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    "path/to/base/model", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image("./conditioning_image_1.png")
prompt = "pale golden rod circle with old lace background"

generator = torch.manual_seed(0)
image = pipeline(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
image.save("./output.png")

Stable Diffusion XL

Stable Diffusion XL (SDXL) 是新一代文生图模型,通过添加第二文本编码器支持生成更高分辨率图像。使用 train_controlnet_sdxl.py 脚本可为SDXL训练ControlNet适配器。

SDXL训练脚本的详细解析请参阅 SDXL训练 指南。

后续步骤

恭喜完成ControlNet训练!如需进一步了解模型应用,以下指南可能有所帮助:

< > Update on GitHub