hungchiayu1
initial commit
ffead1e

Textual Inversion fine-tuning example

Textual inversion is a method for personalizing text-to-image models like Stable Diffusion on your own images using just 3-5 examples. The textual_inversion.py script shows how to implement the training procedure and adapt it for Stable Diffusion.
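As a rough illustration of the idea, the pure-Python toy below adds a placeholder token to a frozen embedding table, initializes it from an existing word (the role of --initializer_token), and updates only that one vector. The squared-distance loss and the target vector are stand-ins chosen for this sketch. The actual script instead back-propagates the diffusion model's denoising loss through otherwise-frozen model weights, so treat this as a conceptual model rather than the script's implementation.

```python
# Toy illustration of the textual-inversion idea (NOT the diffusers implementation):
# every embedding row is frozen except one newly added placeholder row,
# and only that row is updated by gradient descent.
import random

DIM = 4  # toy embedding size (real text encoders use hundreds of dimensions)


def make_vocab(seed=0):
    """Build a tiny random embedding table for a few existing words."""
    rng = random.Random(seed)
    return {tok: [rng.uniform(-1, 1) for _ in range(DIM)] for tok in ("toy", "cat", "photo")}


def add_placeholder(vocab, placeholder, initializer):
    # Initialize the new token from an existing word's embedding (a copy, not an alias)
    vocab[placeholder] = list(vocab[initializer])


def train_placeholder(vocab, placeholder, target, lr=0.1, steps=200):
    """Move only the placeholder embedding toward `target` with plain gradient descent.

    `target` is a stand-in for "the embedding that makes the frozen model reproduce
    the training images"; in real textual inversion the gradient comes from the
    diffusion denoising loss instead of this squared distance.
    """
    frozen = {k: list(v) for k, v in vocab.items() if k != placeholder}
    emb = vocab[placeholder]
    for _ in range(steps):
        # Gradient of sum((emb - target)^2) with respect to emb is 2 * (emb - target)
        grad = [2 * (e - t) for e, t in zip(emb, target)]
        for i in range(DIM):
            emb[i] -= lr * grad[i]
    # Sanity check: no other row was modified
    assert all(vocab[k] == v for k, v in frozen.items())
    return emb


vocab = make_vocab()
add_placeholder(vocab, "<cat-toy>", "toy")
learned = train_placeholder(vocab, "<cat-toy>", target=[0.5, -0.25, 0.0, 1.0])
```

Keeping every other weight fixed is what lets the learned concept be stored as a single small embedding vector.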

Running on Colab

Colab for training

Colab for inference

Running locally with PyTorch

Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

Important

To make sure you can successfully run the latest versions of the example scripts, we highly recommend installing from source and keeping the install up to date, as we update the example scripts frequently and some of them have example-specific requirements. To do this, execute the following steps in a new virtual environment:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

Then cd into the example folder (examples/textual_inversion) and run:

pip install -r requirements.txt

And initialize an 🤗 Accelerate environment with:

accelerate config

Cat toy example

You need to accept the model license before downloading or using the weights. In this example we'll use model version v1-5, so you'll need to visit its card, read the license and tick the checkbox if you agree.

You have to be a registered user of the 🤗 Hugging Face Hub, and you'll also need an access token for the code to work. For more information on access tokens, please refer to this section of the documentation.

Run the following command to authenticate your token:

huggingface-cli login

If you have already cloned the repo, then you won't need to go through these steps.


Now let's get our dataset. Download 3-4 images from here and save them in a directory. This will be our training data.
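Because the training script reads the files in --train_data_dir as images (to our understanding, it does not filter by extension), a stray file such as .DS_Store can break a run. The hypothetical helper below, which is not part of diffusers, is one way to sanity-check the folder before training; the accepted extensions and the three-image minimum are our own assumptions.

```python
# Hypothetical pre-flight check for --train_data_dir (not part of diffusers).
from pathlib import Path

# Assumed set of acceptable image extensions (compared case-insensitively)
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}


def check_train_dir(data_dir, min_images=3) -> list:
    """Return the image paths in data_dir, raising if the folder looks unsuitable."""
    root = Path(data_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"{root} is not a directory")
    entries = sorted(root.iterdir())
    # Anything without an image extension (including subfolders and hidden files
    # like .DS_Store) would otherwise be opened as a training image
    non_images = [p.name for p in entries if p.suffix.lower() not in IMAGE_SUFFIXES]
    if non_images:
        raise ValueError(f"non-image files found: {non_images}")
    if len(entries) < min_images:
        raise ValueError(f"found {len(entries)} images, expected at least {min_images}")
    return entries
```

Calling check_train_dir(os.environ["DATA_DIR"]) before launching training fails fast instead of partway through the first step.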

Then launch training with the command below. If you are using the stable-diffusion-2 768x768 model, change --resolution to 768.

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --output_dir="textual_inversion_cat"

A full training run takes ~1 hour on one V100 GPU.
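With --scale_lr, the PyTorch script multiplies the base learning rate by the number of samples that contribute to each optimizer step: train batch size times gradient accumulation steps times number of processes. The function names below are illustrative and not part of the script; they reproduce that arithmetic for the command above on a single GPU.

```python
# Illustrative helpers (not from textual_inversion.py) reproducing the --scale_lr arithmetic.

def effective_batch_size(train_batch_size, grad_accum_steps, num_processes=1):
    # Number of samples that contribute to one optimizer update
    return train_batch_size * grad_accum_steps * num_processes


def scaled_learning_rate(base_lr, train_batch_size, grad_accum_steps,
                         num_processes=1, scale_lr=True):
    # Without --scale_lr the base learning rate is used unchanged
    if not scale_lr:
        return base_lr
    return base_lr * effective_batch_size(train_batch_size, grad_accum_steps, num_processes)


# Values from the cat-toy command: batch size 1, 4 accumulation steps, 1 GPU
print(effective_batch_size(1, 4))           # 4
print(scaled_learning_rate(5.0e-04, 1, 4))  # 0.002
```

So the command above actually trains at 2e-3, and launching the same command on more GPUs raises the learning rate proportionally.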

Inference

Once you have trained a model with the command above, you can run inference with StableDiffusionPipeline. Make sure to include the placeholder_token in your prompt.

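import torch
from diffusers import StableDiffusionPipeline

# Directory containing the trained pipeline (e.g. the --output_dir used for training)
model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# The prompt must contain the placeholder token learned during training
prompt = "A <cat-toy> backpack"

image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("cat-backpack.png")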
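During training, each image is paired with a caption generated from a small set of templates containing the placeholder token, and with --learnable_property="style" a style-oriented set is used instead. The sketch below imitates that with a few example templates (to our knowledge the script's own lists are longer) and adds a hypothetical check_prompt guard, since the learned embedding only affects generation when the placeholder token actually appears in the prompt.

```python
# Illustrative only: a few caption templates of the kind textual_inversion.py uses,
# plus a hypothetical guard for inference prompts. Neither is the script's code.
import random

OBJECT_TEMPLATES = ["a photo of a {}", "a rendering of a {}"]
STYLE_TEMPLATES = ["a painting in the style of {}", "a rendering in the style of {}"]


def training_caption(placeholder_token, learnable_property, rng):
    # "style" selects style templates; anything else is treated as an object
    templates = STYLE_TEMPLATES if learnable_property == "style" else OBJECT_TEMPLATES
    return rng.choice(templates).format(placeholder_token)


def check_prompt(prompt, placeholder_token):
    # The learned embedding is only looked up if its token appears in the prompt
    if placeholder_token not in prompt:
        raise ValueError(f"prompt does not contain {placeholder_token!r}")
    return prompt


rng = random.Random(0)
print(training_caption("<cat-toy>", "object", rng))
print(check_prompt("A <cat-toy> backpack", "<cat-toy>"))
```

Inference prompts can be phrased freely; they do not need to match a training template as long as they include the token.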

Training with Flax/JAX

For faster training on TPUs and GPUs, you can use the Flax training example. Follow the instructions above to get the model and dataset before running the script.

Before running the script, make sure to install the Flax training dependencies:

pip install -U -r requirements_flax.txt

Then launch training:

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export DATA_DIR="path-to-dir-containing-images"

python textual_inversion_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --output_dir="textual_inversion_cat"

It should be at least 70% faster than the PyTorch script with the same configuration.

Training with xFormers

You can enable memory-efficient attention by installing xFormers and passing the --enable_xformers_memory_efficient_attention argument to the script. This is not available with the Flax/JAX implementation.
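If you drive training from Python, one way to add the flag only when xFormers is installed is sketched below. build_launch_command is a hypothetical helper that mirrors the cat-toy command above; the auto-detection via importlib only checks that the xformers package is importable, not that it works with your GPU.

```python
# Hypothetical helper (not part of diffusers) that assembles the cat-toy launch command.
import importlib.util
import shlex


def build_launch_command(model_name, data_dir, output_dir="textual_inversion_cat",
                         use_xformers=None):
    """Return the `accelerate launch` command as an argument list.

    use_xformers=None auto-detects whether the xformers package is importable.
    """
    if use_xformers is None:
        use_xformers = importlib.util.find_spec("xformers") is not None
    args = [
        "accelerate", "launch", "textual_inversion.py",
        f"--pretrained_model_name_or_path={model_name}",
        f"--train_data_dir={data_dir}",
        "--learnable_property=object",
        "--placeholder_token=<cat-toy>", "--initializer_token=toy",
        "--resolution=512",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=4",
        "--max_train_steps=3000",
        "--learning_rate=5.0e-04", "--scale_lr",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        f"--output_dir={output_dir}",
    ]
    if use_xformers:
        args.append("--enable_xformers_memory_efficient_attention")
    return args


cmd = build_launch_command("runwayml/stable-diffusion-v1-5", "./cat_images", use_xformers=True)
print(shlex.join(cmd))  # shell-quoted form, handy for logging
```

Passing the list to subprocess.run(cmd, check=True) would start the run; keeping it as a list avoids shell-quoting issues with the <cat-toy> token.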