diff --git a/000000000285.jpg b/000000000285.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7d760805c845723a9b2d8efb049ca65c385a9a54 Binary files /dev/null and b/000000000285.jpg differ diff --git a/000000000724.jpg b/000000000724.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2a17e0c6ee400dcba762c4d56dea03d7e124b9c5 Binary files /dev/null and b/000000000724.jpg differ diff --git a/000000007991.jpg b/000000007991.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d5321099964d7444cf034bf8ea819baf25103572 Binary files /dev/null and b/000000007991.jpg differ diff --git a/000000018837.jpg b/000000018837.jpg new file mode 100644 index 0000000000000000000000000000000000000000..51b7b34992cb507620356020b070f61ad5a685c1 Binary files /dev/null and b/000000018837.jpg differ diff --git a/000000122962.jpg b/000000122962.jpg new file mode 100644 index 0000000000000000000000000000000000000000..75688dca25459befea73ecbd8ccc2ac2817c6f4f Binary files /dev/null and b/000000122962.jpg differ diff --git a/000000295478.jpg b/000000295478.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9bbcefa4d25d1548e46e7342730f420deb0a52e3 Binary files /dev/null and b/000000295478.jpg differ diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md new file mode 100644 index 0000000000000000000000000000000000000000..39cbbfa1909998e29b73745b1dc98f6dbbeedf68 --- /dev/null +++ b/ORIGINAL_README.md @@ -0,0 +1,128 @@ +# Text-Guided-Image-Colorization + +This project utilizes the power of **Stable Diffusion (SDXL/SDXL-Light)** and the **BLIP (Bootstrapping Language-Image Pre-training)** captioning model to provide an interactive image colorization experience. Users can influence the generated colors of objects within images, making the colorization process more personalized and creative. + +## Table of Contents + - [Features](#features) + - [Installation](#installation) + - [Quick Start](#quick-start) + - [Dataset Usage](#dataset-usage) + - [Training](#training) + - [Evaluation](#evaluation) + - [Results](#results) + - [License](#license) + +## Features + +- **Interactive Colorization**: Users can specify desired colors for different objects in the image. +- **ControlNet Approach**: Enhanced colorization capabilities through retraining with ControlNet, allowing SDXL to better adapt to the image colorization task. +- **High-Quality Outputs**: Leverage the latest advancements in diffusion models to generate vibrant and realistic colorizations. +- **User-Friendly Interface**: Easy-to-use interface for seamless interaction with the model. + +## Installation + +To set up the project locally, follow these steps: + +1. **Clone the Repository**: + + ```bash + git clone https://github.com/nick8592/text-guided-image-colorization.git + cd text-guided-image-colorization + ``` + +2. **Install Dependencies**: + Make sure you have Python 3.7 or higher installed. Then, install the required packages: + + ```bash + pip install -r requirements.txt + ``` + Install `torch` and `torchvision` matching your CUDA version: + ```bash + pip install torch torchvision --index-url https://download.pytorch.org/whl/cuXXX + ``` + Replace `XXX` with your CUDA version (e.g., `118` for CUDA 11.8). For more info, see [PyTorch Get Started](https://pytorch.org/get-started/locally/). + + +3. **Download Pre-trained Models**: + | Models | Hugging Face (Recommand) | Other | + |:---:|:---:|:---:| + |SDXL-Lightning Caption|[link](https://huggingface.co/nickpai/sdxl_light_caption_output)|[link](https://gofile.me/7uE8s/FlEhfpWPw) (2kNJfV)| + |SDXL-Lightning Custom Caption (Recommand)|[link](https://huggingface.co/nickpai/sdxl_light_custom_caption_output)|[link](https://gofile.me/7uE8s/AKmRq5sLR) (KW7Fpi)| + + + ```bash + text-guided-image-colorization/sdxl_light_caption_output + └── checkpoint-30000 + ├── controlnet + │ ├── diffusion_pytorch_model.safetensors + │ └── config.json + ├── optimizer.bin + ├── random_states_0.pkl + ├── scaler.pt + └── scheduler.bin + ``` + +## Quick Start + +1. Run the `gradio_ui.py` script: + +```bash +python gradio_ui.py +``` + +2. Open the provided URL in your web browser to access the Gradio-based user interface. + +3. Upload an image and use the interface to control the colors of specific objects in the image. But still the model can generate images without a specific prompt. + +4. The model will generate a colorized version of the image based on your input (or automatic). See the [demo video](https://x.com/weichenpai/status/1829513077588631987). +![Gradio UI](images/gradio_ui.png) + + +## Dataset Usage + +You can find more details about the dataset usage in the [Dataset-for-Image-Colorization](https://github.com/nick8592/Dataset-for-Image-Colorization). + +## Training + +For training, you can use one of the following scripts: + +- `train_controlnet.sh`: Trains a model using [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2-1) +- `train_controlnet_sdxl.sh`: Trains a model using [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) +- `train_controlnet_sdxl_light.sh`: Trains a model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) + +Although the training code for SDXL is provided, due to a lack of GPU resources, I wasn't able to train the model by myself. Therefore, there might be some errors when you try to train the model. + +## Evaluation + +For evaluation, you can use one of the following scripts: + +- `eval_controlnet.sh`: Evaluates the model using [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2-1) for a folder of images. +- `eval_controlnet_sdxl_light.sh`: Evaluates the model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) for a folder of images. +- `eval_controlnet_sdxl_light_single.sh`: Evaluates the model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) for a single image. + +## Results +### Prompt-Guided +| Caption | Condition 1 | Condition 2 | Condition 3 | +|:---:|:---:|:---:|:---:| +| ![000000022935_gray.jpg](images/000000022935_gray.jpg) | ![000000022935_green_shirt_on_right_girl.jpeg](images/000000022935_green_shirt_on_right_girl.jpeg) | ![000000022935_purple_shirt_on_right_girl.jpeg](images/000000022935_purple_shirt_on_right_girl.jpeg) |![000000022935_red_shirt_on_right_girl.jpeg](images/000000022935_red_shirt_on_right_girl.jpeg) | +| a photography of a woman in a soccer uniform kicking a soccer ball | + "green shirt"| + "purple shirt" | + "red shirt" | +| ![000000041633_gray.jpg](images/000000041633_gray.jpg) | ![000000041633_bright_red_car.jpeg](images/000000041633_bright_red_car.jpeg) | ![000000041633_dark_blue_car.jpeg](images/000000041633_dark_blue_car.jpeg) |![000000041633_black_car.jpeg](images/000000041633_black_car.jpeg) | +| a photography of a photo of a truck | + "bright red car"| + "dark blue car" | + "black car" | +| ![000000286708_gray.jpg](images/000000286708_gray.jpg) | ![000000286708_orange_hat.jpeg](images/000000286708_orange_hat.jpeg) | ![000000286708_pink_hat.jpeg](images/000000286708_pink_hat.jpeg) |![000000286708_yellow_hat.jpeg](images/000000286708_yellow_hat.jpeg) | +| a photography of a cat wearing a hat on his head | + "orange hat"| + "pink hat" | + "yellow hat" | + +### Prompt-Free +Ground truth images are provided solely for reference purpose in the image colorization task. +| Grayscale Image | Colorized Result | Ground Truth | +|:---:|:---:|:---:| +| ![000000025560_gray.jpg](images/000000025560_gray.jpg) | ![000000025560_color.jpg](images/000000025560_color.jpg) | ![000000025560_gt.jpg](images/000000025560_gt.jpg) | +| ![000000065736_gray.jpg](images/000000065736_gray.jpg) | ![000000065736_color.jpg](images/000000065736_color.jpg) | ![000000065736_gt.jpg](images/000000065736_gt.jpg) | +| ![000000091779_gray.jpg](images/000000091779_gray.jpg) | ![000000091779_color.jpg](images/000000091779_color.jpg) | ![000000091779_gt.jpg](images/000000091779_gt.jpg) | +| ![000000092177_gray.jpg](images/000000092177_gray.jpg) | ![000000092177_color.jpg](images/000000092177_color.jpg) | ![000000092177_gt.jpg](images/000000092177_gt.jpg) | +| ![000000166426_gray.jpg](images/000000166426_gray.jpg) | ![000000166426_color.jpg](images/000000166426_color.jpg) | ![000000025560_gt.jpg](images/000000166426_gt.jpg) | + + + +## License + +This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details. diff --git a/eval_controlnet.py b/eval_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc7037b955fd257db2462e3e9185145028bea04 --- /dev/null +++ b/eval_controlnet.py @@ -0,0 +1,148 @@ +import os +import time +import torch +import shutil +import argparse +import numpy as np + +from tqdm import tqdm +from PIL import Image +from datasets import load_dataset +from diffusers.utils import load_image +from diffusers import StableDiffusionControlNetPipeline, ControlNetModel + +# Define the function to parse arguments +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.") + + parser.add_argument("--model_dir", type=str, default="sd_v2_caption_free_output/checkpoint-22500", + help="Directory of the model checkpoint") + parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-2-base", + help="ID of the model (Tested with runwayml/stable-diffusion-v1-5 and stabilityai/stable-diffusion-2-base)") + parser.add_argument("--dataset", type=str, default="nickpai/coco2017-colorization", + help="Dataset used") + parser.add_argument("--revision", type=str, default="caption-free", + choices=["main", "caption-free"], + help="Revision option (main/caption-free)") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + return args + +def apply_color(image, color_map): + # Convert input images to LAB color space + image_lab = image.convert('LAB') + color_map_lab = color_map.convert('LAB') + + # Split LAB channels + l, a, b = image_lab.split() + _, a_map, b_map = color_map_lab.split() + + # Merge LAB channels with color map + merged_lab = Image.merge('LAB', (l, a_map, b_map)) + + # Convert merged LAB image back to RGB color space + result_rgb = merged_lab.convert('RGB') + + return result_rgb + +def main(args): + generator = torch.manual_seed(0) + + # MODEL_DIR = "sd_v2_caption_free_output/checkpoint-22500" + # # MODEL_ID="runwayml/stable-diffusion-v1-5" + # MODEL_ID="stabilityai/stable-diffusion-2-base" + # DATASET = "nickpai/coco2017-colorization" + # REVISION = "caption-free" # option: main/caption-free + + # Path to the eval_results folder + eval_results_folder = os.path.join(args.model_dir, "results") + + # Remove eval_results folder if it exists + if os.path.exists(eval_results_folder): + shutil.rmtree(eval_results_folder) + + # Create directory for eval_results + os.makedirs(eval_results_folder) + + # Create subfolders for compare and colorized images + compare_folder = os.path.join(eval_results_folder, "compare") + colorized_folder = os.path.join(eval_results_folder, "colorized") + os.makedirs(compare_folder) + os.makedirs(colorized_folder) + + # Load the validation split of the colorization dataset + val_dataset = load_dataset(args.dataset, split="validation", revision=args.revision) + + controlnet = ControlNetModel.from_pretrained(f"{args.model_dir}/controlnet", torch_dtype=torch.float16) + pipe = StableDiffusionControlNetPipeline.from_pretrained( + args.model_id, controlnet=controlnet, torch_dtype=torch.float16 + ).to("cuda") + + pipe.safety_checker = None + + # Counter for processed images + processed_images = 0 + + # Record start time + start_time = time.time() + + # Iterate through the validation dataset + for example in tqdm(val_dataset, desc="Processing Images"): + image_path = example["file_name"] + + prompt = [] + for caption in example["captions"]: + if isinstance(caption, str): + prompt.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + prompt.append(caption[0]) + else: + raise ValueError( + f"Caption column `captions` should contain either strings or lists of strings." + ) + + # Generate image + ground_truth_image = load_image(image_path).resize((512, 512)) + control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512)) + image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0] + + # Apply color mapping + image = apply_color(ground_truth_image, image) + + # Concatenate images into a row + row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image))) + row_image = Image.fromarray(row_image) + + # Save row image in the compare folder + compare_output_path = os.path.join(compare_folder, f"{image_path.split('/')[-1]}") + row_image.save(compare_output_path) + + # Save colorized image in the colorized folder + colorized_output_path = os.path.join(colorized_folder, f"{image_path.split('/')[-1]}") + image.save(colorized_output_path) + + # Increment processed images counter + processed_images += 1 + + # Record end time + end_time = time.time() + + # Calculate total time taken + total_time = end_time - start_time + + # Calculate FPS + fps = processed_images / total_time + + print("All images processed.") + print(f"Total time taken: {total_time:.2f} seconds") + print(f"FPS: {fps:.2f}") + +# Entry point of the script +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/eval_controlnet.sh b/eval_controlnet.sh new file mode 100644 index 0000000000000000000000000000000000000000..f64545bf383e566b616643760a42f137c7bd40fb --- /dev/null +++ b/eval_controlnet.sh @@ -0,0 +1,19 @@ +# Define default values for parameters + +# # sdv2 with BCE loss +# MODEL_DIR="sd_v2_caption_bce_output/checkpoint-22500" +# MODEL_ID="stabilityai/stable-diffusion-2-base" +# DATASET="nickpai/coco2017-colorization" +# REVISION="main" + +# sdv2 with kl loss +MODEL_DIR="sd_v2_caption_kl_output/checkpoint-22500" +MODEL_ID="stabilityai/stable-diffusion-2-base" +DATASET="nickpai/coco2017-colorization" +REVISION="main" + +accelerate launch eval_controlnet.py \ + --model_dir=$MODEL_DIR \ + --model_id=$MODEL_ID \ + --dataset=$DATASET \ + --revision=$REVISION \ No newline at end of file diff --git a/eval_controlnet_sdxl_light.py b/eval_controlnet_sdxl_light.py new file mode 100644 index 0000000000000000000000000000000000000000..300c327907e2dc16af43c728a519450edf9a02ab --- /dev/null +++ b/eval_controlnet_sdxl_light.py @@ -0,0 +1,284 @@ +import os +import time +import torch +import shutil +import argparse +import numpy as np + +from tqdm import tqdm +from PIL import Image +from datasets import load_dataset +from accelerate import Accelerator +from diffusers.utils import load_image +from diffusers import ( + AutoencoderKL, + StableDiffusionXLControlNetPipeline, + ControlNetModel, + UNet2DConditionModel, +) +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + +# Define the function to parse arguments +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.") + + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained controlnet model.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + required=True, + help="Path to output results.", + ) + parser.add_argument( + "--dataset", + type=str, + default="nickpai/coco2017-colorization", + help="Dataset used" + ) + parser.add_argument( + "--dataset_revision", + type=str, + default="caption-free", + choices=["main", "caption-free", "custom-caption"], + help="Revision option (main/caption-free/custom-caption)" + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=8, + help="1-step, 2-step, 4-step, or 8-step distilled models" + ) + parser.add_argument( + "--repo", + type=str, + default="ByteDance/SDXL-Lightning", + required=True, + help="Repository from huggingface.co", + ) + parser.add_argument( + "--ckpt", + type=str, + default="sdxl_lightning_4step_unet.safetensors", + required=True, + help="Available checkpoints from the repository", + ) + parser.add_argument( + "--negative_prompt", + action="store_true", + help="The prompt or prompts not to guide the image generation", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + return args + +def apply_color(image, color_map): + # Convert input images to LAB color space + image_lab = image.convert('LAB') + color_map_lab = color_map.convert('LAB') + + # Split LAB channels + l, a, b = image_lab.split() + _, a_map, b_map = color_map_lab.split() + + # Merge LAB channels with color map + merged_lab = Image.merge('LAB', (l, a_map, b_map)) + + # Convert merged LAB image back to RGB color space + result_rgb = merged_lab.convert('RGB') + + return result_rgb + +def main(args): + generator = torch.manual_seed(0) + + # Path to the eval_results folder + eval_results_folder = os.path.join(args.output_dir, "results") + + # Remove eval_results folder if it exists + if os.path.exists(eval_results_folder): + shutil.rmtree(eval_results_folder) + + # Create directory for eval_results + os.makedirs(eval_results_folder) + + # Create subfolders for compare and colorized images + compare_folder = os.path.join(eval_results_folder, "compare") + colorized_folder = os.path.join(eval_results_folder, "colorized") + os.makedirs(compare_folder) + os.makedirs(colorized_folder) + + # Load the validation split of the colorization dataset + val_dataset = load_dataset(args.dataset, split="validation", revision=args.dataset_revision) + + accelerator = Accelerator( + mixed_precision=args.mixed_precision, + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_config( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant, + ) + unet.load_state_dict(load_file(hf_hub_download(args.repo, args.ckpt))) + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + + controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, torch_dtype=weight_dtype) + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + ) + pipe.to(accelerator.device, dtype=weight_dtype) + + # Prepare everything with our `accelerator`. + pipe, val_dataset = accelerator.prepare(pipe, val_dataset) + + pipe.safety_checker = None + + # Counter for processed images + processed_images = 0 + + # Record start time + start_time = time.time() + + # Iterate through the validation dataset + for example in tqdm(val_dataset, desc="Processing Images"): + image_path = example["file_name"] + + prompt = [] + for caption in example["captions"]: + if isinstance(caption, str): + prompt.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + prompt.append(caption[0]) + else: + raise ValueError( + f"Caption column `captions` should contain either strings or lists of strings." + ) + + negative_prompt = None + if args.negative_prompt: + negative_prompt = [ + "low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate" + ] + + # Generate image + ground_truth_image = load_image(image_path).resize((512, 512)) + control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512)) + image = pipe(prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=args.num_inference_steps, + generator=generator, + image=control_image).images[0] + + # Apply color mapping + image = apply_color(ground_truth_image, image) + + # Concatenate images into a row + row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image))) + row_image = Image.fromarray(row_image) + + # Save row image in the compare folder + compare_output_path = os.path.join(compare_folder, f"{image_path.split('/')[-1]}") + row_image.save(compare_output_path) + + # Save colorized image in the colorized folder + colorized_output_path = os.path.join(colorized_folder, f"{image_path.split('/')[-1]}") + image.save(colorized_output_path) + + # Increment processed images counter + processed_images += 1 + + # Record end time + end_time = time.time() + + # Calculate total time taken + total_time = end_time - start_time + + # Calculate FPS + fps = processed_images / total_time + + print("All images processed.") + print(f"Total time taken: {total_time:.2f} seconds") + print(f"FPS: {fps:.2f}") + +# Entry point of the script +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/eval_controlnet_sdxl_light.sh b/eval_controlnet_sdxl_light.sh new file mode 100644 index 0000000000000000000000000000000000000000..72739952c18e4d3eaf8ef4a5fa598e5d939f6c27 --- /dev/null +++ b/eval_controlnet_sdxl_light.sh @@ -0,0 +1,44 @@ +# Define default values for parameters + +# # sdxl light without negative prompt +# export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0" +# export REPO="ByteDance/SDXL-Lightning" +# export INFERENCE_STEP=8 +# export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step +# export CONTROLNET_MODEL="sdxl_light_custom_caption_output/checkpoint-12500/controlnet" +# export DATASET="nickpai/coco2017-colorization" +# export DATSET_REVISION="custom-caption" +# export OUTPUT_DIR="sdxl_light_custom_caption_output/checkpoint-12500" + +# accelerate launch eval_controlnet_sdxl_light.py \ +# --pretrained_model_name_or_path=$BASE_MODEL \ +# --repo=$REPO \ +# --ckpt=$CKPT \ +# --num_inference_steps=$INFERENCE_STEP \ +# --controlnet_model_name_or_path=$CONTROLNET_MODEL \ +# --dataset=$DATASET \ +# --dataset_revision=$DATSET_REVISION \ +# --mixed_precision="fp16" \ +# --output_dir=$OUTPUT_DIR + +# sdxl light with negative prompt +export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0" +export REPO="ByteDance/SDXL-Lightning" +export INFERENCE_STEP=8 +export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step +export CONTROLNET_MODEL="sdxl_light_caption_output/checkpoint-22500/controlnet" +export DATASET="nickpai/coco2017-colorization" +export DATSET_REVISION="custom-caption" +export OUTPUT_DIR="sdxl_light_caption_output/checkpoint-22500" + +accelerate launch eval_controlnet_sdxl_light.py \ + --pretrained_model_name_or_path=$BASE_MODEL \ + --repo=$REPO \ + --ckpt=$CKPT \ + --num_inference_steps=$INFERENCE_STEP \ + --controlnet_model_name_or_path=$CONTROLNET_MODEL \ + --dataset=$DATASET \ + --dataset_revision=$DATSET_REVISION \ + --mixed_precision="fp16" \ + --output_dir=$OUTPUT_DIR \ + --negative_prompt \ No newline at end of file diff --git a/eval_controlnet_sdxl_light_single.py b/eval_controlnet_sdxl_light_single.py new file mode 100644 index 0000000000000000000000000000000000000000..74e7cf1c443c46b6bad49ab28f915e3ac6fd83ed --- /dev/null +++ b/eval_controlnet_sdxl_light_single.py @@ -0,0 +1,390 @@ +import os +import PIL +import time +import torch +import argparse + +from typing import Optional, Union +from accelerate import Accelerator +from diffusers import ( + AutoencoderKL, + StableDiffusionXLControlNetPipeline, + ControlNetModel, + UNet2DConditionModel, +) +from transformers import ( + BlipProcessor, BlipForConditionalGeneration, + VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer +) +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + +# Define the function to parse arguments +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.") + parser.add_argument( + "--image_path", + type=str, + default="example/legacy_images/Hollywood-Sign.jpg", + required=True, + help="Path to the image", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained controlnet model.", + ) + parser.add_argument( + "--caption_model_name", + type=str, + default="blip-image-captioning-large", + choices=["blip-image-captioning-large", "blip-image-captioning-base"], + help="Path to pretrained controlnet model.", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=8, + help="1-step, 2-step, 4-step, or 8-step distilled models" + ) + parser.add_argument( + "--repo", + type=str, + default="ByteDance/SDXL-Lightning", + required=True, + help="Repository from huggingface.co", + ) + parser.add_argument( + "--ckpt", + type=str, + default="sdxl_lightning_4step_unet.safetensors", + required=True, + help="Available checkpoints from the repository", + ) + parser.add_argument( + "--seed", + type=int, + default=123, + help="Random seeds" + ) + parser.add_argument( + "--positive_prompt", + type=str, + help="Text for positive prompt", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate", + help="Text for negative prompt", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + return args + +def apply_color(image, color_map): + # Convert input images to LAB color space + image_lab = image.convert('LAB') + color_map_lab = color_map.convert('LAB') + + # Split LAB channels + l, a, b = image_lab.split() + _, a_map, b_map = color_map_lab.split() + + # Merge LAB channels with color map + merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map)) + + # Convert merged LAB image back to RGB color space + result_rgb = merged_lab.convert('RGB') + + return result_rgb + +def remove_unlikely_words(prompt: str) -> str: + """ + Removes unlikely words from a prompt. + + Args: + prompt: The text prompt to be cleaned. + + Returns: + The cleaned prompt with unlikely words removed. + """ + unlikely_words = [] + + a1_list = [f'{i}s' for i in range(1900, 2000)] + a2_list = [f'{i}' for i in range(1900, 2000)] + a3_list = [f'year {i}' for i in range(1900, 2000)] + a4_list = [f'circa {i}' for i in range(1900, 2000)] + b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list] + b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] + b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] + b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] + + words_list = [ + "black and white,", "black and white", "black & white,", "black & white", "circa", + "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,", + "black - and - white photography,", "monochrome bw,", "black white,", "black an white,", + "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo", + "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy", + "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w", + "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo", + "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,", + "black-and-white photo,", "black-and-white photo", "black - and - white photography", + "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic", + "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo", + "black - and - white photograph,", "black - and - white photograph", "black on white,", + "black on white", "black-and-white", "historical image,", "historical picture,", + "historical photo,", "historical photograph,", "archival photo,", "taken in the early", + "taken in the late", "taken in the", "historic photograph,", "restored,", "restored", + "historical photo", "historical setting,", + "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated", + "taken in", "shot on leica", "shot on leica sl2", "sl2", + "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting", + "overcast day", "overcast weather", "slight overcast", "overcast", + "picture taken in", "photo taken in", + ", photo", ", photo", ", photo", ", photo", ", photograph", + ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,", + ] + + unlikely_words.extend(a1_list) + unlikely_words.extend(a2_list) + unlikely_words.extend(a3_list) + unlikely_words.extend(a4_list) + unlikely_words.extend(b1_list) + unlikely_words.extend(b2_list) + unlikely_words.extend(b3_list) + unlikely_words.extend(b4_list) + unlikely_words.extend(words_list) + + for word in unlikely_words: + prompt = prompt.replace(word, "") + return prompt + +def blip_image_captioning(image: PIL.Image.Image, + model_backbone: str, + weight_dtype: type, + device: str, + conditional: bool) -> str: + # https://huggingface.co/Salesforce/blip-image-captioning-large + # https://huggingface.co/Salesforce/blip-image-captioning-base + if weight_dtype == torch.bfloat16: # in case model might not accept bfloat16 data type + weight_dtype = torch.float16 + + processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}") + model = BlipForConditionalGeneration.from_pretrained( + f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device) + + valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"] + if model_backbone not in valid_backbones: + raise ValueError(f"Invalid model backbone '{model_backbone}'. \ + Valid options are: {', '.join(valid_backbones)}") + + if conditional: + text = "a photography of" + inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype) + else: + inputs = processor(image, return_tensors="pt").to(device) + out = model.generate(**inputs) + caption = processor.decode(out[0], skip_special_tokens=True) + return caption + +import matplotlib.pyplot as plt + +def display_images(input_image, output_image, ground_truth): + """ + Displays a grid of input, output, ground truth images with a caption at the bottom. + + Args: + input_image: A grayscale image as a NumPy array. + output_image: A grayscale image (result) as a NumPy array. + ground_truth: A grayscale image (ground truth) as a NumPy array. + """ + fig, axes = plt.subplots(1, 3, figsize=(20, 8)) + + axes[0].imshow(input_image, cmap='gray') + axes[0].set_title('Input') + axes[0].axis('off') + + axes[1].imshow(output_image) + axes[1].set_title('Output') + axes[1].axis('off') + + axes[2].imshow(ground_truth) + axes[2].set_title('Ground Truth') + axes[2].axis('off') + + plt.tight_layout() + plt.show() + +# Define a function to process the image with the loaded model +def process_image(image_path: str, + controlnet_model_name_or_path: str, + caption_model_name: str, + positive_prompt: Optional[str], + negative_prompt: Optional[str], + seed: int, + num_inference_steps: int, + mixed_precision: str, + pretrained_model_name_or_path: str, + pretrained_vae_model_name_or_path: Optional[str], + revision: Optional[str], + variant: Optional[str], + repo: str, + ckpt: str,) -> PIL.Image.Image: + # Seed + generator = torch.manual_seed(seed) + + # Accelerator Setting + accelerator = Accelerator( + mixed_precision=mixed_precision, + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + vae_path = ( + pretrained_model_name_or_path + if pretrained_vae_model_name_or_path is None + else pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if pretrained_vae_model_name_or_path is None else None, + revision=revision, + variant=variant, + ) + unet = UNet2DConditionModel.from_config( + pretrained_model_name_or_path, + subfolder="unet", + revision=revision, + variant=variant, + ) + unet.load_state_dict(load_file(hf_hub_download(repo, ckpt))) + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if pretrained_vae_model_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + + controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype) + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + ) + pipe.to(accelerator.device, dtype=weight_dtype) + + image = PIL.Image.open(image_path) + + # Prepare everything with our `accelerator`. + pipe, image = accelerator.prepare(pipe, image) + pipe.safety_checker = None + + # Convert image into grayscale + original_size = image.size + control_image = image.convert("L").convert("RGB").resize((512, 512)) + + # Image captioning + if caption_model_name == "blip-image-captioning-large" or "blip-image-captioning-base": + caption = blip_image_captioning(control_image, caption_model_name, + weight_dtype, accelerator.device, conditional=True) + # elif caption_model_name == "ViT-L-14/openai" or "ViT-H-14/laion2b_s32b_b79k": + # caption = clip_image_captioning(control_image, caption_model_name, accelerator.device) + # elif caption_model_name == "vit-gpt2-image-captioning": + # caption = vit_gpt2_image_captioning(control_image, accelerator.device) + caption = remove_unlikely_words(caption) + + print("================================================================") + print(f"Positive prompt: \n>>> {positive_prompt}") + print(f"Negative prompt: \n>>> {negative_prompt}") + print(f"Caption results: \n>>> {caption}") + print("================================================================") + + # Combine positive prompt and captioning result + prompt = [positive_prompt + ", " + caption] + + # Image colorization + image = pipe(prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=num_inference_steps, + generator=generator, + image=control_image).images[0] + + # Apply color mapping + result_image = apply_color(control_image, image) + result_image = result_image.resize(original_size) + return result_image, caption + +def main(args): + output_image, output_caption = process_image(image_path=args.image_path, + controlnet_model_name_or_path=args.controlnet_model_name_or_path, + caption_model_name=args.caption_model_name, + positive_prompt=args.positive_prompt, + negative_prompt=args.negative_prompt, + seed=args.seed, + num_inference_steps=args.num_inference_steps, + mixed_precision=args.mixed_precision, + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + pretrained_vae_model_name_or_path=args.pretrained_vae_model_name_or_path, + revision=args.revision, + variant=args.variant, + repo=args.repo, + ckpt=args.ckpt,) + input_image = PIL.Image.open(args.image_path) + display_images(input_image.convert("L"), output_image, input_image) + return output_image, output_caption + +# Entry point of the script +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/eval_controlnet_sdxl_light_single.sh b/eval_controlnet_sdxl_light_single.sh new file mode 100644 index 0000000000000000000000000000000000000000..2db9994b3ceb39558a54cd99e11d5ec9ce24a90c --- /dev/null +++ b/eval_controlnet_sdxl_light_single.sh @@ -0,0 +1,20 @@ +# sdxl light for single image +export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0" +export REPO="ByteDance/SDXL-Lightning" +export INFERENCE_STEP=8 +export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step +export CONTROLNET_MODEL="sdxl_light_caption_output/checkpoint-30000/controlnet" +export CAPTION_MODEL="blip-image-captioning-large" +export IMAGE_PATH="example/legacy_images/Hollywood-Sign.jpg" +# export POSITIVE_PROMPT="blue shirt" + +accelerate launch eval_controlnet_sdxl_light_single.py \ + --pretrained_model_name_or_path=$BASE_MODEL \ + --repo=$REPO \ + --ckpt=$CKPT \ + --num_inference_steps=$INFERENCE_STEP \ + --controlnet_model_name_or_path=$CONTROLNET_MODEL \ + --caption_model_name=$CAPTION_MODEL \ + --mixed_precision="fp16" \ + --image_path=$IMAGE_PATH \ + --positive_prompt="red car" \ No newline at end of file diff --git a/example/UUColor_results/Hollywood-Sign.jpeg b/example/UUColor_results/Hollywood-Sign.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..36d167423ba1e6349c38c53b9c7b71d95d74dc2b Binary files /dev/null and b/example/UUColor_results/Hollywood-Sign.jpeg differ diff --git a/example/legacy_images/Big-Ben-vintage.jpg b/example/legacy_images/Big-Ben-vintage.jpg new file mode 100644 index 0000000000000000000000000000000000000000..59d8e63dbc81d6b9855ed66711ffc22d024a683a Binary files /dev/null and b/example/legacy_images/Big-Ben-vintage.jpg differ diff --git a/example/legacy_images/Central-Park.jpg b/example/legacy_images/Central-Park.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5a94e2c3414fa89f851eda8dcb8fef9dcb3f5984 Binary files /dev/null and b/example/legacy_images/Central-Park.jpg differ diff --git a/example/legacy_images/Hollywood-Sign.jpg b/example/legacy_images/Hollywood-Sign.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c9cf818b1808ae7af71c0bffba12846a719d1432 Binary files /dev/null and b/example/legacy_images/Hollywood-Sign.jpg differ diff --git a/example/legacy_images/Little-Mermaid.jpg b/example/legacy_images/Little-Mermaid.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63f2ac5784dbbe0b140b6892d6587f4b991e8737 Binary files /dev/null and b/example/legacy_images/Little-Mermaid.jpg differ diff --git a/example/legacy_images/Migrant-Mother.jpg b/example/legacy_images/Migrant-Mother.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d2feec89fba07f2bef3c1640016fc5f1c715863 Binary files /dev/null and b/example/legacy_images/Migrant-Mother.jpg differ diff --git a/example/legacy_images/Mount-Everest.jpg b/example/legacy_images/Mount-Everest.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d535f1f6bcb2bb5b835990228da61b8f323d4ef2 Binary files /dev/null and b/example/legacy_images/Mount-Everest.jpg differ diff --git a/example/legacy_images/Tower-of-Pisa.jpg b/example/legacy_images/Tower-of-Pisa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9fbd1e48716544231f6eadd0c03b3773b217aae3 Binary files /dev/null and b/example/legacy_images/Tower-of-Pisa.jpg differ diff --git a/example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg b/example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5fa0ca087c1e37bb232410a9778867ab20202b8a Binary files /dev/null and b/example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg differ diff --git a/gradio_ui.py b/gradio_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2bb32c22e9a28f4d573d85cf38b137cec243bd --- /dev/null +++ b/gradio_ui.py @@ -0,0 +1,356 @@ +import PIL +import torch +import subprocess +import gradio as gr + +from typing import Optional +from accelerate import Accelerator +from diffusers import ( + AutoencoderKL, + StableDiffusionXLControlNetPipeline, + ControlNetModel, + UNet2DConditionModel, +) +from transformers import ( + BlipProcessor, BlipForConditionalGeneration, + VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer +) +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file +from clip_interrogator import Interrogator, Config, list_clip_models + +def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image: + # Convert input images to LAB color space + image_lab = image.convert('LAB') + color_map_lab = color_map.convert('LAB') + + # Split LAB channels + l, a , b = image_lab.split() + _, a_map, b_map = color_map_lab.split() + + # Merge LAB channels with color map + merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map)) + + # Convert merged LAB image back to RGB color space + result_rgb = merged_lab.convert('RGB') + return result_rgb + +def remove_unlikely_words(prompt: str) -> str: + """ + Removes unlikely words from a prompt. + + Args: + prompt: The text prompt to be cleaned. + + Returns: + The cleaned prompt with unlikely words removed. + """ + unlikely_words = [] + + a1_list = [f'{i}s' for i in range(1900, 2000)] + a2_list = [f'{i}' for i in range(1900, 2000)] + a3_list = [f'year {i}' for i in range(1900, 2000)] + a4_list = [f'circa {i}' for i in range(1900, 2000)] + b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list] + b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] + b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] + b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list] + + words_list = [ + "black and white,", "black and white", "black & white,", "black & white", "circa", + "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,", + "black - and - white photography,", "monochrome bw,", "black white,", "black an white,", + "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo", + "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy", + "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w", + "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo", + "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,", + "black-and-white photo,", "black-and-white photo", "black - and - white photography", + "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic", + "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo", + "black - and - white photograph,", "black - and - white photograph", "black on white,", + "black on white", "black-and-white", "historical image,", "historical picture,", + "historical photo,", "historical photograph,", "archival photo,", "taken in the early", + "taken in the late", "taken in the", "historic photograph,", "restored,", "restored", + "historical photo", "historical setting,", + "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated", + "taken in", "shot on leica", "shot on leica sl2", "sl2", + "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting", + "overcast day", "overcast weather", "slight overcast", "overcast", + "picture taken in", "photo taken in", + ", photo", ", photo", ", photo", ", photo", ", photograph", + ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,", + ] + + unlikely_words.extend(a1_list) + unlikely_words.extend(a2_list) + unlikely_words.extend(a3_list) + unlikely_words.extend(a4_list) + unlikely_words.extend(b1_list) + unlikely_words.extend(b2_list) + unlikely_words.extend(b3_list) + unlikely_words.extend(b4_list) + unlikely_words.extend(words_list) + + for word in unlikely_words: + prompt = prompt.replace(word, "") + return prompt + +def blip_image_captioning(image: PIL.Image.Image, + model_backbone: str, + weight_dtype: type, + device: str, + conditional: bool) -> str: + # https://huggingface.co/Salesforce/blip-image-captioning-large + # https://huggingface.co/Salesforce/blip-image-captioning-base + if weight_dtype == torch.bfloat16: # in case model might not accept bfloat16 data type + weight_dtype = torch.float16 + + processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}") + model = BlipForConditionalGeneration.from_pretrained( + f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device) + + valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"] + if model_backbone not in valid_backbones: + raise ValueError(f"Invalid model backbone '{model_backbone}'. \ + Valid options are: {', '.join(valid_backbones)}") + + if conditional: + text = "a photography of" + inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype) + else: + inputs = processor(image, return_tensors="pt").to(device) + out = model.generate(**inputs) + caption = processor.decode(out[0], skip_special_tokens=True) + return caption + +# def vit_gpt2_image_captioning(image: PIL.Image.Image, device: str) -> str: +# # https://huggingface.co/nlpconnect/vit-gpt2-image-captioning +# model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device) +# feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning") +# tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning") + +# max_length = 16 +# num_beams = 4 +# gen_kwargs = {"max_length": max_length, "num_beams": num_beams} + +# pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values +# pixel_values = pixel_values.to(device) + +# output_ids = model.generate(pixel_values, **gen_kwargs) + +# preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True) +# caption = [pred.strip() for pred in preds] + +# return caption[0] + +# def clip_image_captioning(image: PIL.Image.Image, +# clip_model_name: str, +# device: str) -> str: +# # validate clip model name +# models = list_clip_models() +# if clip_model_name not in models: +# raise ValueError(f"Could not find CLIP model {clip_model_name}! \ +# Available models: {models}") +# config = Config(device=device, clip_model_name=clip_model_name) +# config.apply_low_vram_defaults() +# ci = Interrogator(config) +# caption = ci.interrogate(image) +# return caption + +# Define a function to process the image with the loaded model +def process_image(image_path: str, + controlnet_model_name_or_path: str, + caption_model_name: str, + positive_prompt: Optional[str], + negative_prompt: Optional[str], + seed: int, + num_inference_steps: int, + mixed_precision: str, + pretrained_model_name_or_path: str, + pretrained_vae_model_name_or_path: Optional[str], + revision: Optional[str], + variant: Optional[str], + repo: str, + ckpt: str,) -> PIL.Image.Image: + # Seed + generator = torch.manual_seed(seed) + + # Accelerator Setting + accelerator = Accelerator( + mixed_precision=mixed_precision, + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + vae_path = ( + pretrained_model_name_or_path + if pretrained_vae_model_name_or_path is None + else pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if pretrained_vae_model_name_or_path is None else None, + revision=revision, + variant=variant, + ) + unet = UNet2DConditionModel.from_config( + pretrained_model_name_or_path, + subfolder="unet", + revision=revision, + variant=variant, + ) + unet.load_state_dict(load_file(hf_hub_download(repo, ckpt))) + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if pretrained_vae_model_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + + controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype) + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + ) + pipe.to(accelerator.device, dtype=weight_dtype) + + image = PIL.Image.open(image_path) + + # Prepare everything with our `accelerator`. + pipe, image = accelerator.prepare(pipe, image) + pipe.safety_checker = None + + # Convert image into grayscale + original_size = image.size + control_image = image.convert("L").convert("RGB").resize((512, 512)) + + # Image captioning + if caption_model_name == "blip-image-captioning-large" or "blip-image-captioning-base": + caption = blip_image_captioning(control_image, caption_model_name, + weight_dtype, accelerator.device, conditional=True) + # elif caption_model_name == "ViT-L-14/openai" or "ViT-H-14/laion2b_s32b_b79k": + # caption = clip_image_captioning(control_image, caption_model_name, accelerator.device) + # elif caption_model_name == "vit-gpt2-image-captioning": + # caption = vit_gpt2_image_captioning(control_image, accelerator.device) + caption = remove_unlikely_words(caption) + + # Combine positive prompt and captioning result + prompt = [positive_prompt + ", " + caption] + + # Image colorization + image = pipe(prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=num_inference_steps, + generator=generator, + image=control_image).images[0] + + # Apply color mapping + result_image = apply_color(control_image, image) + result_image = result_image.resize(original_size) + return result_image, caption + +# Define the image gallery based on folder path +def get_image_paths(folder_path): + import os + image_paths = [] + for filename in os.listdir(folder_path): + if filename.endswith(".jpg") or filename.endswith(".png"): + image_paths.append([os.path.join(folder_path, filename)]) + return image_paths + +# Create the Gradio interface +def create_interface(): + controlnet_model_dict = { + "sdxl-light-caption-30000": "sdxl_light_caption_output/checkpoint-30000/controlnet", + "sdxl-light-custom-caption-30000": "sdxl_light_custom_caption_output/checkpoint-30000/controlnet", + } + images = get_image_paths("example/legacy_images") # Replace with your folder path + + interface = gr.Interface( + fn=process_image, + inputs=[ + gr.Image(label="Upload image", + value="example/legacy_images/Hollywood-Sign.jpg", + type='filepath'), + gr.Dropdown(choices=[controlnet_model_dict[key] for key in controlnet_model_dict], + value=controlnet_model_dict["sdxl-light-caption-30000"], + label="Select ControlNet Model"), + gr.Dropdown(choices=["blip-image-captioning-large", + "blip-image-captioning-base",], + value="blip-image-captioning-large", + label="Select Image Captioning Model"), + gr.Textbox(label="Positive Prompt", placeholder="Text for positive prompt"), + gr.Textbox(value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate", + label="Negative Prompt", placeholder="Text for negative prompt"), + ], + outputs=[ + gr.Image(label="Colorized image", + value="example/UUColor_results/Hollywood-Sign.jpeg", + format="jpeg"), + gr.Textbox(label="Captioning Result", show_copy_button=True) + ], + examples=images, + additional_inputs=[ + # gr.Radio(choices=["Original", "Square"], value="Original", + # label="Output resolution"), + # gr.Slider(minimum=128, maximum=512, value=256, step=128, + # label="Height & Width", + # info='Only effect if select "Square" output resolution'), + gr.Slider(0, 1000, 123, label="Seed"), + gr.Radio(choices=[1, 2, 4, 8], + value=8, + label="Inference Steps", + info="1-step, 2-step, 4-step, or 8-step distilled models"), + gr.Radio(choices=["no", "fp16", "bf16"], + value="fp16", + label="Mixed Precision", + info="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."), + gr.Dropdown(choices=["stabilityai/stable-diffusion-xl-base-1.0"], + value="stabilityai/stable-diffusion-xl-base-1.0", + label="Base Model", + info="Path to pretrained model or model identifier from huggingface.co/models."), + gr.Dropdown(choices=["None"], + value=None, + label="VAE Model", + info="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038."), + gr.Dropdown(choices=["None"], + value=None, + label="Varient", + info="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16"), + gr.Dropdown(choices=["None"], + value=None, + label="Revision", + info="Revision of pretrained model identifier from huggingface.co/models."), + gr.Dropdown(choices=["ByteDance/SDXL-Lightning"], + value="ByteDance/SDXL-Lightning", + label="Repository", + info="Repository from huggingface.co"), + gr.Dropdown(choices=["sdxl_lightning_1step_unet.safetensors", + "sdxl_lightning_2step_unet.safetensors", + "sdxl_lightning_4step_unet.safetensors", + "sdxl_lightning_8step_unet.safetensors"], + value="sdxl_lightning_8step_unet.safetensors", + label="Checkpoint", + info="Available checkpoints from the repository. Caution! Checkpoint's 'N'step must match with inference steps"), + ], + title="Text-Guided Image Colorization", + description="Upload an image and select a model to colorize it." + ) + return interface + +def main(): + # Launch the Gradio interface + interface = create_interface() + interface.launch() + +if __name__ == "__main__": + main() diff --git a/images/000000022935_gray.jpg b/images/000000022935_gray.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63495f225a70a6e2bd99fdc94abba18ee0279f28 Binary files /dev/null and b/images/000000022935_gray.jpg differ diff --git a/images/000000022935_green_shirt_on_right_girl.jpeg b/images/000000022935_green_shirt_on_right_girl.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..c7d459769aafbd3bb1480e72fe2490c825ec2780 Binary files /dev/null and b/images/000000022935_green_shirt_on_right_girl.jpeg differ diff --git a/images/000000022935_purple_shirt_on_right_girl.jpeg b/images/000000022935_purple_shirt_on_right_girl.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..79e7187600bd243ffdce3a2e3aad3ee04fd12f50 Binary files /dev/null and b/images/000000022935_purple_shirt_on_right_girl.jpeg differ diff --git a/images/000000022935_red_shirt_on_right_girl.jpeg b/images/000000022935_red_shirt_on_right_girl.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..d6882c3df5a52f3e5c3549477ad7dd893cca4b66 Binary files /dev/null and b/images/000000022935_red_shirt_on_right_girl.jpeg differ diff --git a/images/000000025560_color.jpg b/images/000000025560_color.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b37bef3b475c6343d2c39b15711b088368b7f25d Binary files /dev/null and b/images/000000025560_color.jpg differ diff --git a/images/000000025560_gray.jpg b/images/000000025560_gray.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8426a5355b12f6fc6a007d1b140294d7672487a9 Binary files /dev/null and b/images/000000025560_gray.jpg differ diff --git a/images/000000025560_gt.jpg b/images/000000025560_gt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e292e82af4c8a1b754705fb1994806dad261d0fc Binary files /dev/null and b/images/000000025560_gt.jpg differ diff --git a/images/000000041633_black_car.jpeg b/images/000000041633_black_car.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..7f93ab3031f5289f92b3cf1b2e06e9ca91b1c953 Binary files /dev/null and b/images/000000041633_black_car.jpeg differ diff --git a/images/000000041633_bright_red_car.jpeg b/images/000000041633_bright_red_car.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..9b4f2199207294c8ce6c3dd72065039a0c581983 Binary files /dev/null and b/images/000000041633_bright_red_car.jpeg differ diff --git a/images/000000041633_dark_blue_car.jpeg b/images/000000041633_dark_blue_car.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..524a55b5e7c7385d15ce44d38df0bd13200b1181 Binary files /dev/null and b/images/000000041633_dark_blue_car.jpeg differ diff --git a/images/000000041633_gray.jpg b/images/000000041633_gray.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8a8ccba7c9b3371398161154f6009f255b0259b5 Binary files /dev/null and b/images/000000041633_gray.jpg differ diff --git a/images/000000065736_color.jpg b/images/000000065736_color.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cbcdb18e77b79826fab965a6beb1ea99b4272666 Binary files /dev/null and b/images/000000065736_color.jpg differ diff --git a/images/000000065736_gray.jpg b/images/000000065736_gray.jpg new file mode 100644 index 0000000000000000000000000000000000000000..df988f1927baefdb2c9fd61c46afc0acfe0a1550 Binary files /dev/null and b/images/000000065736_gray.jpg differ diff --git a/images/000000065736_gt.jpg b/images/000000065736_gt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d8b5789010fe5f6df5f5b6905c73d1c0ffae1107 Binary files /dev/null and b/images/000000065736_gt.jpg differ diff --git a/images/000000091779_color.jpg b/images/000000091779_color.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a322031aa0d572d51576650ab71561afdf9645c3 Binary files /dev/null and b/images/000000091779_color.jpg differ diff --git a/images/000000091779_gray.jpg b/images/000000091779_gray.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bf8989c2d9ed2f8b48881c221b338d31b12568e2 Binary files /dev/null and b/images/000000091779_gray.jpg differ diff --git a/images/000000091779_gt.jpg b/images/000000091779_gt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f8e8ff3cafbf2ed88d6390b58af890dab9065c7d Binary files /dev/null and b/images/000000091779_gt.jpg differ diff --git a/images/000000092177_color.jpg b/images/000000092177_color.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0e274f285ab5c423bb89bcf5f3cbf0383c9c9643 Binary files /dev/null and b/images/000000092177_color.jpg differ diff --git a/images/000000092177_gray.jpg b/images/000000092177_gray.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fb69145958c6e0ec1094663f60f5395fb58e1f38 Binary files /dev/null and b/images/000000092177_gray.jpg differ diff --git a/images/000000092177_gt.jpg b/images/000000092177_gt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ea2a3798e5216df04e1ac24925b04fa7e27309f1 Binary files /dev/null and b/images/000000092177_gt.jpg differ diff --git a/images/000000166426_color.jpg b/images/000000166426_color.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ece363acb7b42cac29aa29f85f42145ea10998ab Binary files /dev/null and b/images/000000166426_color.jpg differ diff --git a/images/000000166426_gray.jpg b/images/000000166426_gray.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b083a1ae854efe432f1b08558253c126f436072f Binary files /dev/null and b/images/000000166426_gray.jpg differ diff --git a/images/000000166426_gt.jpg b/images/000000166426_gt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2d4ebd1efc6fb375fc46c5b36936411de4991f1b Binary files /dev/null and b/images/000000166426_gt.jpg differ diff --git a/images/000000286708_gray.jpg b/images/000000286708_gray.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2271b6c13630ac2a927576d2d5d45bb333885748 Binary files /dev/null and b/images/000000286708_gray.jpg differ diff --git a/images/000000286708_orange_hat.jpeg b/images/000000286708_orange_hat.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..d9db3a27409800e9cd817fc25106dfd8cb37e8c5 Binary files /dev/null and b/images/000000286708_orange_hat.jpeg differ diff --git a/images/000000286708_pink_hat.jpeg b/images/000000286708_pink_hat.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..9deaa638ed84e8cc6af6036faf76fc0ab852e3df Binary files /dev/null and b/images/000000286708_pink_hat.jpeg differ diff --git a/images/000000286708_yellow_hat.jpeg b/images/000000286708_yellow_hat.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..505ce55fdb784cf8bd82568c1e659335bd3efc86 Binary files /dev/null and b/images/000000286708_yellow_hat.jpeg differ diff --git a/images/gradio_ui.png b/images/gradio_ui.png new file mode 100644 index 0000000000000000000000000000000000000000..ea7a7d2d8483bbdea53cd5aec3b139cf2ed07e05 Binary files /dev/null and b/images/gradio_ui.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7984eed36f9baf68a4a68b1780e7565c0f933720 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +accelerate>=0.16.0 +# torch==1.13.1+cu117 +# torchvision==0.14.1+cu117 +transformers>=4.25.1 +ftfy +tensorboard +datasets +bitsandbytes +git+https://github.com/huggingface/diffusers diff --git a/train_controlnet.py b/train_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6074bc961577ff5277197f70f8484fdd57f06ceb --- /dev/null +++ b/train_controlnet.py @@ -0,0 +1,1193 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import contextlib +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDPMScheduler, + StableDiffusionControlNetPipeline, + UNet2DConditionModel, + UniPCMultistepScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.28.0.dev0") + +logger = get_logger(__name__) + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def log_validation( + vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False +): + logger.info("Running validation... ") + + if not is_final_validation: + controlnet = accelerator.unwrap_model(controlnet) + else: + controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + + pipeline = StableDiffusionControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + safety_checker=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda") + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB").resize((512, 512)) # resize to prevent size mismatch when stacking + + images = [] + + for _ in range(args.num_validation_images): + with inference_ctx: + image = pipeline( + validation_prompt, validation_image, num_inference_steps=20, generator=generator + ).images[0] + + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "stable-diffusion", + "stable-diffusion-diffusers", + "text-to-image", + "diffusers", + "controlnet", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--dataset_revision", + type=str, + default='main', + help="The revision of the Dataset, leave as 'main' by default.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def make_train_dataset(args, tokenizer, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + revision=args.dataset_revision, + cache_dir=args.cache_dir, + ) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if random.random() < args.proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Grayscale(num_output_channels=3), # convert to grayscale image + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [Image.open(image).convert("RGB") for image in examples[image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [Image.open(image).convert("RGB") for image in examples[conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + examples["input_ids"] = tokenize_captions(examples) + + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + return train_dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.stack([example["input_ids"] for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from unet") + controlnet = ControlNetModel.from_unet(unet) + + # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + controlnet.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = controlnet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + train_dataset = make_train_dataset(args, tokenizer, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, unet and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] + + controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + return_dict=False, + )[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + vae, + text_encoder, + tokenizer, + unet, + controlnet, + args, + accelerator, + weight_dtype, + global_step, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = unwrap_model(controlnet) + controlnet.save_pretrained(args.output_dir) + + # Run a final round of validation. + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/train_controlnet.sh b/train_controlnet.sh new file mode 100644 index 0000000000000000000000000000000000000000..3ce2a90adfd0beeaeb02fafeb923052d5d0c9855 --- /dev/null +++ b/train_controlnet.sh @@ -0,0 +1,40 @@ +# Original ControlNet paper: +# "In the training process, we randomly replace 50% text prompts ct with empty strings. +# This approach increases ControlNet’s ability to directly recognize semantics +# in the input conditioning images (e.g., edges, poses, depth, etc.) as a replacement for the prompt." +# https://civitai.com/articles/2078/play-in-control-controlnet-training-setup-guide + +# export MODEL_DIR="runwayml/stable-diffusion-v1-5" +export MODEL_DIR="stabilityai/stable-diffusion-2-base" +export OUTPUT_DIR="sd_v2_caption_kl_output" +export DATASET="nickpai/coco2017-colorization" +export REVISION="main" # option: main/caption-free +export VAL_IMG_NAME="'./000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg'" +export VAL_PROMPT="'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.'" +# export VAL_PROMPT="'Colorize this image as if it was taken with a color camera' 'Colorize this image' 'Add colors to this image' 'Make this image colorful' 'Colorize this grayscale image' 'Add colors to this image'" + +accelerate launch train_controlnet.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --seed=123123 \ + --dataset_name=$DATASET \ + --dataset_revision=$REVISION \ + --image_column="file_name" \ + --conditioning_image_column="file_name" \ + --caption_column="captions" \ + --max_train_samples=100000 \ + --num_validation_images=1 \ + --resolution=512 \ + --num_train_epochs=5 \ + --dataloader_num_workers=8 \ + --learning_rate=1e-5 \ + --validation_image './000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg' \ + --validation_prompt 'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.' \ + --train_batch_size=2 \ + --gradient_accumulation_steps=8 \ + --proportion_empty_prompts=0 \ + --validation_steps=500 \ + --checkpointing_steps=2500 \ + --mixed_precision="fp16" \ + --gradient_checkpointing \ + --use_8bit_adam \ No newline at end of file diff --git a/train_controlnet_sdxl.py b/train_controlnet_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..e06fdbe2499d876aebb34b1f10a9d580ebcea363 --- /dev/null +++ b/train_controlnet_sdxl.py @@ -0,0 +1,1330 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import functools +import gc +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDPMScheduler, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, + UniPCMultistepScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.28.0.dev0") + +logger = get_logger(__name__) + + +def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): + logger.info("Running validation... ") + + if not is_final_validation: + controlnet = accelerator.unwrap_model(controlnet) + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + else: + controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + if args.pretrained_vae_model_name_or_path is not None: + vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype) + else: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype + ) + + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + if is_final_validation or torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator + ).images[0] + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "stable-diffusion-xl", + "stable-diffusion-xl-diffusers", + "text-to-image", + "diffusers", + "controlnet", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--dataset_revision", + type=str, + default='main', + help="The revision of the Dataset, leave as 'main' by default.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="sd_xl_train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def get_train_dataset(args, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + revision=args.dataset_revision, + cache_dir=args.cache_dir, + ) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def prepare_train_dataset(dataset, accelerator): + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Grayscale(num_output_channels=3), # convert to grayscale image + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [Image.open(image).convert("RGB") for image in examples[args.image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [Image.open(image).convert("RGB") for image in examples[args.conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + + add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples]) + add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "prompt_ids": prompt_ids, + "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from unet") + controlnet = ControlNetModel.from_unet(unet) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + controlnet.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = controlnet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True): + original_size = (args.resolution, args.resolution) + target_size = (args.resolution, args.resolution) + crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + prompt_batch = batch[args.caption_column] + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + train_dataset = get_train_dataset(args, accelerator) + compute_embeddings_fn = functools.partial( + compute_embeddings, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + ) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) + + del text_encoders, tokenizers + gc.collect() + torch.cuda.empty_cache() + + # Then get the training dataset ready to be passed to the dataloader. + train_dataset = prepare_train_dataset(train_dataset, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + else: + pixel_values = batch["pixel_values"] + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # ControlNet conditioning. + controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + return_dict=False, + )[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + vae=vae, + unet=unet, + controlnet=controlnet, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = unwrap_model(controlnet) + controlnet.save_pretrained(args.output_dir) + + # Run a final round of validation. + # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`. + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + vae=None, + unet=None, + controlnet=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/train_controlnet_sdxl.sh b/train_controlnet_sdxl.sh new file mode 100644 index 0000000000000000000000000000000000000000..d2a51fa4ef4dd1807671032b1726c466953450b5 --- /dev/null +++ b/train_controlnet_sdxl.sh @@ -0,0 +1,42 @@ +# Original ControlNet paper: +# "In the training process, we randomly replace 50% text prompts ct with empty strings. +# This approach increases ControlNet’s ability to directly recognize semantics +# in the input conditioning images (e.g., edges, poses, depth, etc.) as a replacement for the prompt." +# https://civitai.com/articles/2078/play-in-control-controlnet-training-setup-guide + +export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0" +export OUTPUT_DIR="sdxl_caption_output" +export DATASET="nickpai/coco2017-colorization" +export REVISION="main" # option: main/caption-free +export VAL_IMG_NAME="'./000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg'" +export VAL_PROMPT="'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.'" +# export VAL_PROMPT="'Colorize this image as if it was taken with a color camera' 'Colorize this image' 'Add colors to this image' 'Make this image colorful' 'Colorize this grayscale image' 'Add colors to this image'" + +accelerate launch train_controlnet_sdxl.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --seed=123123 \ + --dataset_name=$DATASET \ + --dataset_revision=$REVISION \ + --image_column="file_name" \ + --conditioning_image_column="file_name" \ + --caption_column="captions" \ + --max_train_samples=100000 \ + --num_validation_images=1 \ + --resolution=512 \ + --num_train_epochs=5 \ + --dataloader_num_workers=8 \ + --learning_rate=1e-5 \ + --train_batch_size=2 \ + --gradient_accumulation_steps=8 \ + --proportion_empty_prompts=0 \ + --validation_steps=500 \ + --checkpointing_steps=2500 \ + --mixed_precision="fp16" \ + --gradient_checkpointing \ + --use_8bit_adam \ + --enable_xformers_memory_efficient_attention + +# --validation_image './000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg' \ +# --validation_prompt 'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.' \ + \ No newline at end of file diff --git a/train_controlnet_sdxl_light.py b/train_controlnet_sdxl_light.py new file mode 100644 index 0000000000000000000000000000000000000000..426004e56d0e73231b67b3b034ccddf8cb54e037 --- /dev/null +++ b/train_controlnet_sdxl_light.py @@ -0,0 +1,1359 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import functools +import gc +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator, cpu_offload +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDPMScheduler, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, + UniPCMultistepScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.28.0.dev0") + +logger = get_logger(__name__) + + +def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): + logger.info("Running validation... ") + + if not is_final_validation: + controlnet = accelerator.unwrap_model(controlnet) + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + else: + controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + if args.pretrained_vae_model_name_or_path is not None: + vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype) + else: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype + ) + + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + if is_final_validation or torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt=validation_prompt, image=validation_image, num_inference_steps=args.num_inference_steps, generator=generator + ).images[0] + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "stable-diffusion-xl", + "stable-diffusion-xl-diffusers", + "text-to-image", + "diffusers", + "controlnet", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--dataset_revision", + type=str, + default='main', + help="The revision of the Dataset, leave as 'main' by default.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="sd_xl_train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=4, + help="1-step, 2-step, 4-step, or 8-step distilled models" + ) + parser.add_argument( + "--repo", + type=str, + default="ByteDance/SDXL-Lightning", + required=True, + help="Repository from huggingface.co", + ) + parser.add_argument( + "--ckpt", + type=str, + default="sdxl_lightning_4step_unet.safetensors", + required=True, + help="Available checkpoints from the repository", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def get_train_dataset(args, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + revision=args.dataset_revision, + cache_dir=args.cache_dir, + ) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def prepare_train_dataset(dataset, accelerator): + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Grayscale(num_output_channels=3), # convert to grayscale image + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [Image.open(image).convert("RGB") for image in examples[args.image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [Image.open(image).convert("RGB") for image in examples[args.conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + + add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples]) + add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "prompt_ids": prompt_ids, + "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + print(vae) + input() + # unet = UNet2DConditionModel.from_pretrained( + # args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + # ) + unet = UNet2DConditionModel.from_config( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + unet.load_state_dict(load_file(hf_hub_download(args.repo, args.ckpt))) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from unet") + controlnet = ControlNetModel.from_unet(unet) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + controlnet.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = controlnet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True): + original_size = (args.resolution, args.resolution) + target_size = (args.resolution, args.resolution) + crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + prompt_batch = batch[args.caption_column] + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + train_dataset = get_train_dataset(args, accelerator) + compute_embeddings_fn = functools.partial( + compute_embeddings, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + ) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) + + del text_encoders, tokenizers + gc.collect() + torch.cuda.empty_cache() + + # Then get the training dataset ready to be passed to the dataloader. + train_dataset = prepare_train_dataset(train_dataset, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + else: + pixel_values = batch["pixel_values"] + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # ControlNet conditioning. + controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + return_dict=False, + )[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + vae=vae, + unet=unet, + controlnet=controlnet, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = unwrap_model(controlnet) + controlnet.save_pretrained(args.output_dir) + + # Run a final round of validation. + # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`. + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + vae=None, + unet=None, + controlnet=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/train_controlnet_sdxl_light.sh b/train_controlnet_sdxl_light.sh new file mode 100644 index 0000000000000000000000000000000000000000..c880456e59eb02ba0097925a4c421ea77159285c --- /dev/null +++ b/train_controlnet_sdxl_light.sh @@ -0,0 +1,50 @@ +# Original ControlNet paper: +# "In the training process, we randomly replace 50% text prompts ct with empty strings. +# This approach increases ControlNet’s ability to directly recognize semantics +# in the input conditioning images (e.g., edges, poses, depth, etc.) as a replacement for the prompt." +# https://civitai.com/articles/2078/play-in-control-controlnet-training-setup-guide + +export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0" +export REPO="ByteDance/SDXL-Lightning" +export INFERENCE_STEP=8 +export CKPT="sdxl_lightning_8step_unet.safetensors" # caution!!! ckpt's "N"step must match with inference_step +export OUTPUT_DIR="test" +export PROJECT_NAME="train_sdxl_light_controlnet" +export DATASET="nickpai/coco2017-colorization" +export REVISION="custom-caption" # option: main/caption-free/custom-caption +export VAL_IMG_NAME="'./000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg'" +export VAL_PROMPT="'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.'" +# export VAL_PROMPT="'Colorize this image as if it was taken with a color camera' 'Colorize this image' 'Add colors to this image' 'Make this image colorful' 'Colorize this grayscale image' 'Add colors to this image'" + +accelerate launch train_controlnet_sdxl_light.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --tracker_project_name=$PROJECT_NAME \ + --seed=123123 \ + --dataset_name=$DATASET \ + --dataset_revision=$REVISION \ + --image_column="file_name" \ + --conditioning_image_column="file_name" \ + --caption_column="captions" \ + --max_train_samples=100000 \ + --num_validation_images=1 \ + --resolution=512 \ + --num_train_epochs=5 \ + --dataloader_num_workers=8 \ + --learning_rate=1e-5 \ + --train_batch_size=2 \ + --gradient_accumulation_steps=8 \ + --proportion_empty_prompts=0 \ + --validation_steps=500 \ + --checkpointing_steps=2500 \ + --mixed_precision="fp16" \ + --gradient_checkpointing \ + --use_8bit_adam \ + --repo=$REPO \ + --ckpt=$CKPT \ + --num_inference_steps=$INFERENCE_STEP \ + --enable_xformers_memory_efficient_attention + +# --validation_image './000000295478.jpg' './000000122962.jpg' './000000000285.jpg' './000000007991.jpg' './000000018837.jpg' './000000000724.jpg' \ +# --validation_prompt 'Woman walking a small dog behind her.' 'A group of children sitting at a long table eating pizza.' 'A close up picture of a bear face.' 'A plate on a table is filled with carrots and beans.' 'A large truck on a city street with two works sitting on top and one worker climbing in through door.' 'An upside down stop sign by the road.' \ + \ No newline at end of file