# Stable Diffusion text-to-image fine-tuning

The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) script shows how to fine-tune the Stable Diffusion model on your own dataset.

The text-to-image fine-tuning script is experimental. It's easy to overfit and run into issues like catastrophic forgetting. We recommend exploring different hyperparameters to get the best results on your dataset.

## Running locally

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate your token:

```bash
huggingface-cli login
```

If you have already cloned the model repo, you won't need to go through these steps. Instead, you can pass the path to your local checkout to the training script and the model will be loaded from there.

### Hardware Requirements for Fine-tuning

Using `gradient_checkpointing` and `mixed_precision`, it should be possible to fine-tune the model on a single 24GB GPU. For a higher `batch_size` and faster training, it's better to use GPUs with more than 30GB of GPU memory.
You can also use JAX / Flax for fine-tuning on TPUs or GPUs; see [below](#flax--jax-fine-tuning) for details.

### Fine-tuning Example

The following script will launch a fine-tuning run using [Justin Pinkney's captioned Pokemon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions), available on the Hugging Face Hub.

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"
```

To run on your own training files, you need to prepare the dataset according to the format required by `datasets`. You can upload your dataset to the Hub, or you can prepare a local folder with your files. [This documentation](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata) explains how to do it.

You should modify the script if you wish to use custom loading logic.
We have left pointers in the code in the appropriate places :)

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR}
```

Once training is finished, the model will be saved to the `OUTPUT_DIR` specified in the command. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:

```python
import torch
from diffusers import StableDiffusionPipeline

# Path that was passed as --output_dir during training
model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```

### Flax / JAX fine-tuning

Thanks to [@duongna21](https://github.com/duongna21) it's possible to fine-tune Stable Diffusion using Flax! This is very efficient on TPU hardware but works great on GPUs too. You can use the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export dataset_name="lambdalabs/pokemon-blip-captions"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-pokemon-model"
```
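
If you train from a local folder with `--train_data_dir` instead of a Hub dataset, the folder needs captions alongside the images. The following is a minimal sketch of one way to prepare it, assuming the `imagefolder` layout with a `metadata.jsonl` file and assuming the caption column is left at `text`. The helper `write_metadata` is hypothetical and not part of the training scripts.

```python
import json
from pathlib import Path


def write_metadata(train_dir, captions):
    """Write metadata.jsonl into train_dir, one JSON object per image.

    `captions` maps an image file name (relative to train_dir) to its caption.
    Hypothetical helper for illustration; not part of the training scripts.
    """
    train_dir = Path(train_dir)

    # Fail early if a caption refers to an image that is not in the folder.
    missing = [name for name in captions if not (train_dir / name).is_file()]
    if missing:
        raise FileNotFoundError(f"images not found in {train_dir}: {missing}")

    metadata_path = train_dir / "metadata.jsonl"
    with metadata_path.open("w", encoding="utf-8") as f:
        for file_name, text in sorted(captions.items()):
            # "file_name" links each row to its image file.
            # "text" is assumed to match the caption column the script reads.
            f.write(json.dumps({"file_name": file_name, "text": text}) + "\n")
    return metadata_path
```

After calling, for example, `write_metadata("path_to_your_dataset", {"0001.png": "a drawing of a green pokemon"})`, point `TRAIN_DIR` at that folder. If your captions use a different key, the script's caption column setting (believed to be `--caption_column`) would need to match it.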