Denoising Diffusion Policy Optimization

The why

[Figure: sample generations before and after DDPO finetuning]

Getting started with Stable Diffusion finetuning with reinforcement learning

Finetuning Stable Diffusion models with reinforcement learning relies heavily on Hugging Face's diffusers library, so getting started requires some familiarity with two of its core concepts: pipelines and schedulers. Out of the box, diffusers provides neither a Pipeline nor a Scheduler instance that is suitable for finetuning with reinforcement learning, so some adjustments need to be made.

This library provides a pipeline interface that must be implemented in order to work with the DDPOTrainer, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. Note: only the Stable Diffusion architecture is supported at this point. A default implementation of this interface is available out of the box; assuming it is sufficient, or simply to get things moving, refer to the training example alongside this guide.

The point of the interface is to fuse the pipeline and the scheduler into a single object, so that all the constraints live in one place. The interface was designed in the hope of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at the time of writing. The scheduler step is also a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible via the interface, but it is the only way to constrain the scheduler step output to an output type suited to the algorithm at hand (DDPO).

For a more detailed look at the interface and the associated default implementation, go here.

Note that the default implementation has both a LoRA and a non-LoRA implementation path. LoRA is enabled by default and can be turned off by passing the corresponding flag. LoRA-based training is faster, and the LoRA-related hyperparameters responsible for model convergence are not as finicky as in non-LoRA training.
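As a rough illustration of toggling that flag, here is a minimal sketch. The model id is a placeholder, and the use_lora keyword is assumed (not confirmed here) to be the relevant argument of the default implementation; verify against the linked source.

# Sketch only: constructing the default pipeline with the LoRA path disabled.
# The model id is a placeholder and `use_lora` is assumed to be the relevant
# keyword argument of the default implementation.
from trl import DefaultDDPOStableDiffusionPipeline

pipeline = DefaultDDPOStableDiffusionPipeline(
    "runwayml/stable-diffusion-v1-5",  # placeholder pretrained model id
    use_lora=False,                    # the LoRA path is enabled by default
)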

In addition, you are expected to provide a reward function and a prompt function. The reward function is used to evaluate the generated images, and the prompt function is used to generate the prompts from which the images are generated.
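To make this concrete, below is a hedged sketch of a reward function and a prompt function wired into the DDPOTrainer documented further down. The brightness-based reward, the hard-coded prompt list, and the model id are illustrative placeholders rather than anything TRL provides; the reward function returns a rewards tensor together with a metadata payload, mirroring the rewards/reward_metadata pair consumed by the logging hook described below.

# Sketch only: placeholder reward and prompt functions for DDPOTrainer.
import random

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline


def reward_fn(images, prompts, prompt_metadata):
    # Toy reward: mean pixel brightness per generated image, standing in for a
    # real scorer (e.g. an aesthetic model). Returns rewards plus metadata.
    rewards = images.float().mean(dim=(1, 2, 3))
    return rewards, {}


def prompt_fn():
    # Each call returns a single prompt and its (here empty) metadata.
    return random.choice(["squirrel", "crab", "starfish"]), {}


config = DDPOConfig(sample_batch_size=2, train_batch_size=1)
pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")  # placeholder model id

trainer = DDPOTrainer(config, reward_fn, prompt_fn, pipeline)
trainer.train()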

Getting started with examples/scripts/ddpo.py

The ddpo.py script is a working example of using the DDPO trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (DDPOConfig).

Note: one A100 GPU is recommended to get this running. Anything below an A100 will not be able to run this example script, and even if it does run with relatively smaller parameter settings, the results will most likely be poor.

Almost every configuration parameter has a default. Only one command-line argument is required to get things up and running: the user is expected to have a Hugging Face user access token, which will be used to upload the model to the Hugging Face Hub after finetuning. Enter the following bash command to get things running:

python ddpo.py --hf_user_access_token <token>

To obtain the documentation of ddpo.py, please run python ddpo.py --help

The following are things to keep in mind in general (the code checks them for you as well) while configuring the trainer, beyond the use case of the example script; a sketch of these checks follows the list:

  • The configurable sample batch size (--ddpo_config.sample_batch_size=6) should be greater than or equal to the configurable training batch size (--ddpo_config.train_batch_size=3)
  • The configurable sample batch size (--ddpo_config.sample_batch_size=6) must be divisible by the configurable train batch size (--ddpo_config.train_batch_size=3)
  • The configurable sample batch size (--ddpo_config.sample_batch_size=6) must be divisible by both the configurable gradient accumulation steps (--ddpo_config.train_gradient_accumulation_steps=1) and the configurable accelerator processes count
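As mentioned above, here is a sketch of those checks. It is an illustrative restatement of the rules, not the trainer's actual validation code.

# Illustrative restatement of the batch-size rules above
# (the trainer performs equivalent checks for you).
def check_batch_sizes(sample_batch_size, train_batch_size,
                      gradient_accumulation_steps, num_processes):
    assert sample_batch_size >= train_batch_size
    assert sample_batch_size % train_batch_size == 0
    assert sample_batch_size % gradient_accumulation_steps == 0
    assert sample_batch_size % num_processes == 0


# The example values from the flags above pass, assuming a single accelerator process.
check_batch_sizes(sample_batch_size=6, train_batch_size=3,
                  gradient_accumulation_steps=1, num_processes=1)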

Setting up the image logging hook function

Expect the function to be given a list of lists of the form

[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]

where image, prompt, prompt_metadata, rewards, and reward_metadata are batched. The last list in the list of lists represents the last sample batch, which is the one you will most likely want to log. While you are free to log however you want, the use of wandb or tensorboard is recommended.

Key terms

  • rewards : The rewards/score is a numerical value associated with the generated image and is key to steering the RL process
  • reward_metadata : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
  • prompt : The prompt is the text that is used to generate the image
  • prompt_metadata : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a FLAVA setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
  • image : The image generated by the Stable Diffusion model

Example code for logging sampled images with wandb is given below.

# for logging these images to wandb
import numpy as np
from PIL import Image


def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch
    # hence we extract the last element of the list
    result = {}
    images, prompts, _, rewards, _ = image_data[-1]
    for i, image in enumerate(images):
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )

Using the finetuned model

Assuming you're done with all the epochs and have pushed your model up to the Hub, you can use the finetuned model as follows:


import torch
from trl import DefaultDDPOStableDiffusionPipeline

pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)

prompts = ["squirrel", "crab", "starfish", "whale", "sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts, results.images):
    image.save(f"{prompt}.png")

Credits

This work is heavily influenced by the repo here and the associated paper Training Diffusion Models with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine.

DDPOTrainer

class trl.DDPOTrainer


( config: DDPOConfig reward_function: typing.Callable[[torch.Tensor, typing.Tuple[str], typing.Tuple[typing.Any]], torch.Tensor] prompt_function: typing.Callable[[], typing.Tuple[str, typing.Any]] sd_pipeline: DDPOStableDiffusionPipeline image_samples_hook: typing.Optional[typing.Callable[[typing.Any, typing.Any, typing.Any], typing.Any]] = None )

Parameters

  • config (DDPOConfig) — Configuration object for DDPOTrainer. Check the documentation of DDPOConfig for more details.
  • reward_function (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) — Reward function to be used.
  • prompt_function (Callable[[], Tuple[str, Any]]) — Function to generate prompts to guide the model.
  • sd_pipeline (DDPOStableDiffusionPipeline) — Stable Diffusion pipeline to be used for training.
  • image_samples_hook (Optional[Callable[[Any, Any, Any], Any]]) — Hook to be called to log images.

The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimize diffusion models. Note: this trainer is heavily inspired by the work at https://github.com/kvablack/ddpo-pytorch. As of now, only Stable Diffusion based pipelines are supported.

calculate_loss


( latents timesteps next_latents log_probs advantages embeds )

Parameters

  • latents (torch.Tensor) — The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
  • timesteps (torch.Tensor) — The timesteps sampled from the diffusion model, shape: [batch_size]
  • next_latents (torch.Tensor) — The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
  • log_probs (torch.Tensor) — The log probabilities of the latents, shape: [batch_size]
  • advantages (torch.Tensor) — The advantages of the latents, shape: [batch_size]
  • embeds (torch.Tensor) — The embeddings of the prompts, shape: [2*batch_size or batch_size, …] Note: the “or” is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds

Calculate the loss for a batch of an unpacked sample

create_model_card


( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, typing.List[str], NoneType] = None )

Parameters

  • model_name (str, optional, defaults to None) — The name of the model.
  • dataset_name (str, optional, defaults to None) — The name of the dataset used for training.
  • tags (str, List[str] or None, optional, defaults to None) — Tags to be associated with the model card.

Creates a draft of a model card using the information available to the Trainer.

step


( epoch: int global_step: int ) → global_step (int)

Parameters

  • epoch (int) — The current epoch.
  • global_step (int) — The current global step.

Returns

global_step (int)

The updated global step.

Perform a single step of training.

Side Effects:

  • Model weights are updated.
  • Logs the statistics to the accelerator trackers.
  • If self.image_samples_callback is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

train


( epochs: typing.Optional[int] = None )

Train the model for a given number of epochs
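Conceptually, and only as a sketch against the documented signatures rather than the literal implementation, train drives step like this (trainer and config are assumed to exist as in the earlier wiring example):

# Conceptual sketch of how train() uses step(); not the literal implementation.
global_step = 0
for epoch in range(config.num_epochs):
    global_step = trainer.step(epoch, global_step)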

DDPOConfig

class trl.DDPOConfig


( exp_name: str = 'doc-buil' run_name: str = '' seed: int = 0 log_with: typing.Optional[typing.Literal['wandb', 'tensorboard']] = None tracker_kwargs: dict = <factory> accelerator_kwargs: dict = <factory> project_kwargs: dict = <factory> tracker_project_name: str = 'trl' logdir: str = 'logs' num_epochs: int = 100 save_freq: int = 1 num_checkpoint_limit: int = 5 mixed_precision: str = 'fp16' allow_tf32: bool = True resume_from: str = '' sample_num_steps: int = 50 sample_eta: float = 1.0 sample_guidance_scale: float = 5.0 sample_batch_size: int = 1 sample_num_batches_per_epoch: int = 2 train_batch_size: int = 1 train_use_8bit_adam: bool = False train_learning_rate: float = 0.0003 train_adam_beta1: float = 0.9 train_adam_beta2: float = 0.999 train_adam_weight_decay: float = 0.0001 train_adam_epsilon: float = 1e-08 train_gradient_accumulation_steps: int = 1 train_max_grad_norm: float = 1.0 train_num_inner_epochs: int = 1 train_cfg: bool = True train_adv_clip_max: float = 5.0 train_clip_range: float = 0.0001 train_timestep_fraction: float = 1.0 per_prompt_stat_tracking: bool = False per_prompt_stat_tracking_buffer_size: int = 16 per_prompt_stat_tracking_min_count: int = 16 async_reward_computation: bool = False max_workers: int = 2 negative_prompts: str = '' push_to_hub: bool = False )

Parameters

  • exp_name (str, optional, defaults to os.path.basename(sys.argv[0])[: -len(".py")]) — Name of this experiment (by default, the file name without the extension).
  • run_name (str, optional, defaults to "") — Name of this run.
  • seed (int, optional, defaults to 0) — Random seed.
  • log_with (Optional[Literal["wandb", "tensorboard"]], optional, defaults to None) — Log with either ‘wandb’ or ‘tensorboard’, check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
  • tracker_kwargs (Dict, optional, defaults to {}) — Keyword arguments for the tracker (e.g. wandb_project).
  • accelerator_kwargs (Dict, optional, defaults to {}) — Keyword arguments for the accelerator.
  • project_kwargs (Dict, optional, defaults to {}) — Keyword arguments for the accelerator project config (e.g. logging_dir).
  • tracker_project_name (str, optional, defaults to "trl") — Name of project to use for tracking.
  • logdir (str, optional, defaults to "logs") — Top-level logging directory for checkpoint saving.
  • num_epochs (int, optional, defaults to 100) — Number of epochs to train.
  • save_freq (int, optional, defaults to 1) — Number of epochs between saving model checkpoints.
  • num_checkpoint_limit (int, optional, defaults to 5) — Number of checkpoints to keep before overwriting old ones.
  • mixed_precision (str, optional, defaults to "fp16") — Mixed precision training.
  • allow_tf32 (bool, optional, defaults to True) — Allow tf32 on Ampere GPUs.
  • resume_from (str, optional, defaults to "") — Resume training from a checkpoint.
  • sample_num_steps (int, optional, defaults to 50) — Number of sampler inference steps.
  • sample_eta (float, optional, defaults to 1.0) — Eta parameter for the DDIM sampler.
  • sample_guidance_scale (float, optional, defaults to 5.0) — Classifier-free guidance weight.
  • sample_batch_size (int, optional, defaults to 1) — Batch size (per GPU) to use for sampling.
  • sample_num_batches_per_epoch (int, optional, defaults to 2) — Number of batches to sample per epoch.
  • train_batch_size (int, optional, defaults to 1) — Batch size (per GPU) to use for training.
  • train_use_8bit_adam (bool, optional, defaults to False) — Use 8bit Adam optimizer from bitsandbytes.
  • train_learning_rate (float, optional, defaults to 3e-4) — Learning rate.
  • train_adam_beta1 (float, optional, defaults to 0.9) — Adam beta1.
  • train_adam_beta2 (float, optional, defaults to 0.999) — Adam beta2.
  • train_adam_weight_decay (float, optional, defaults to 1e-4) — Adam weight decay.
  • train_adam_epsilon (float, optional, defaults to 1e-8) — Adam epsilon.
  • train_gradient_accumulation_steps (int, optional, defaults to 1) — Number of gradient accumulation steps.
  • train_max_grad_norm (float, optional, defaults to 1.0) — Maximum gradient norm for gradient clipping.
  • train_num_inner_epochs (int, optional, defaults to 1) — Number of inner epochs per outer epoch.
  • train_cfg (bool, optional, defaults to True) — Whether or not to use classifier-free guidance during training.
  • train_adv_clip_max (float, optional, defaults to 5.0) — Clip advantages to the range.
  • train_clip_range (float, optional, defaults to 1e-4) — PPO clip range.
  • train_timestep_fraction (float, optional, defaults to 1.0) — Fraction of timesteps to train on.
  • per_prompt_stat_tracking (bool, optional, defaults to False) — Whether to track statistics for each prompt separately.
  • per_prompt_stat_tracking_buffer_size (int, optional, defaults to 16) — Number of reward values to store in the buffer for each prompt.
  • per_prompt_stat_tracking_min_count (int, optional, defaults to 16) — Minimum number of reward values to store in the buffer.
  • async_reward_computation (bool, optional, defaults to False) — Whether to compute rewards asynchronously.
  • max_workers (int, optional, defaults to 2) — Maximum number of workers to use for async reward computation.
  • negative_prompts (Optional[str], optional, defaults to "") — Comma-separated list of prompts to use as negative examples.
  • push_to_hub (bool, optional, defaults to False) — Whether to push the final model checkpoint to the Hub.

Configuration class for the DDPOTrainer.

Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.
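For instance, a minimal sketch that exposes only DDPOConfig on the command line (the ddpo.py example additionally parses its own script-specific arguments):

# Sketch: expose DDPOConfig fields as command-line flags via HfArgumentParser.
from transformers import HfArgumentParser

from trl import DDPOConfig

parser = HfArgumentParser(DDPOConfig)
(ddpo_config,) = parser.parse_args_into_dataclasses()
print(ddpo_config.num_epochs, ddpo_config.train_learning_rate)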
