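"""Evaluation script for a ControlNet colorization pipeline built on SDXL-Lightning.

For each image in the validation split of the colorization dataset, the script
conditions the ControlNet on a grayscale copy of the image, transfers the
generated chrominance onto the original luminance via LAB color mapping, and
saves both a side-by-side comparison strip and the colorized result. Overall
throughput (FPS) is reported at the end.
"""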
import os
import time
import torch
import shutil
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
from datasets import load_dataset
from accelerate import Accelerator
from diffusers.utils import load_image
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Define the function to parse arguments
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained controlnet model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        required=True,
        help="Path to output results.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nickpai/coco2017-colorization",
        help="Dataset used for evaluation.",
    )
    parser.add_argument(
        "--dataset_revision",
        type=str,
        default="caption-free",
        choices=["main", "caption-free", "custom-caption"],
        help="Dataset revision to load (main/caption-free/custom-caption).",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current"
            " system or the flag passed with the `accelerate launch` command. Use this argument to override"
            " the accelerate config."
        ),
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=8,
        help="Number of denoising steps; use 1, 2, 4, or 8 to match the chosen SDXL-Lightning distilled checkpoint.",
    )
    parser.add_argument(
        "--repo",
        type=str,
        default="ByteDance/SDXL-Lightning",
        help="Repository on huggingface.co that hosts the distilled UNet checkpoints.",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="sdxl_lightning_4step_unet.safetensors",
        help="Checkpoint file to download from the repository.",
    )
    parser.add_argument(
        "--negative_prompt",
        action="store_true",
        help="Use a predefined negative prompt to steer generation away from grayscale, low-quality results.",
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()
    return args

def apply_color(image, color_map):
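    """Transfer the chrominance of `color_map` onto the luminance of `image`.

    Both images are converted to LAB; the L (lightness) channel of `image` is
    merged with the A/B (color) channels of `color_map`, so the generated
    colors are applied without altering the original image structure.
    """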
    # Convert input images to LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')

    # Split LAB channels
    l, a, b = image_lab.split()
    _, a_map, b_map = color_map_lab.split()

    # Merge LAB channels with color map
    merged_lab = Image.merge('LAB', (l, a_map, b_map))

    # Convert merged LAB image back to RGB color space
    result_rgb = merged_lab.convert('RGB')
    return result_rgb

def main(args):
    generator = torch.manual_seed(0)

    # Path to the eval_results folder
    eval_results_folder = os.path.join(args.output_dir, "results")

    # Remove eval_results folder if it exists
    if os.path.exists(eval_results_folder):
        shutil.rmtree(eval_results_folder)

    # Create directory for eval_results
    os.makedirs(eval_results_folder)

    # Create subfolders for compare and colorized images
    compare_folder = os.path.join(eval_results_folder, "compare")
    colorized_folder = os.path.join(eval_results_folder, "colorized")
    os.makedirs(compare_folder)
    os.makedirs(colorized_folder)

    # Load the validation split of the colorization dataset
    val_dataset = load_dataset(args.dataset, split="validation", revision=args.dataset_revision)
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae_path = (
        args.pretrained_model_name_or_path
        if args.pretrained_vae_model_name_or_path is None
        else args.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
        revision=args.revision,
        variant=args.variant,
    )
    unet = UNet2DConditionModel.from_config(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
        variant=args.variant,
    )
    unet.load_state_dict(load_file(hf_hub_download(args.repo, args.ckpt)))
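    # SDXL-Lightning publishes distilled UNet weights only, so the UNet is
    # instantiated from the base model's config and the distilled weights are
    # loaded on top of it.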
    # Move vae and unet to device and cast to weight_dtype.
    # The VAE is kept in float32 to avoid NaN losses.
    if args.pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=torch.float32)
    unet.to(accelerator.device, dtype=weight_dtype)

    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, torch_dtype=weight_dtype)

    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        unet=unet,
        controlnet=controlnet,
    )
    pipe.to(accelerator.device, dtype=weight_dtype)
    if args.pretrained_vae_model_name_or_path is None:
        # `pipe.to` above casts every component; restore the VAE to float32
        # so the stability intent stated above is actually preserved.
        vae.to(dtype=torch.float32)
    # Prepare everything with our `accelerator`.
    pipe, val_dataset = accelerator.prepare(pipe, val_dataset)
    pipe.safety_checker = None
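    # Note: SDXL pipelines do not ship a safety checker, so this assignment is
    # effectively a no-op kept for parity with SD 1.x-style scripts.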
    # Counter for processed images
    processed_images = 0

    # Record start time
    start_time = time.time()
    # Iterate through the validation dataset
    for example in tqdm(val_dataset, desc="Processing Images"):
        image_path = example["file_name"]

        prompt = []
        for caption in example["captions"]:
            if isinstance(caption, str):
                prompt.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take the first caption if there are multiple
                prompt.append(caption[0])
            else:
                raise ValueError(
                    "Caption column `captions` should contain either strings or lists of strings."
                )
        negative_prompt = None
        if args.negative_prompt:
            negative_prompt = [
                "low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate"
            ]
        # Generate image
        ground_truth_image = load_image(image_path).resize((512, 512))
        control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512))
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
            image=control_image,
        ).images[0]
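        # The grayscale control image constrains structure while the prompt
        # (and optional negative prompt) guides color; with SDXL-Lightning only
        # a few denoising steps are needed.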
        # Apply color mapping
        image = apply_color(ground_truth_image, image)

        # Concatenate images into a row
        row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image)))
        row_image = Image.fromarray(row_image)

        # Save row image in the compare folder
        compare_output_path = os.path.join(compare_folder, os.path.basename(image_path))
        row_image.save(compare_output_path)

        # Save colorized image in the colorized folder
        colorized_output_path = os.path.join(colorized_folder, os.path.basename(image_path))
        image.save(colorized_output_path)

        # Increment processed images counter
        processed_images += 1
    # Record end time
    end_time = time.time()

    # Calculate total time taken
    total_time = end_time - start_time

    # Calculate FPS
    fps = processed_images / total_time

    print("All images processed.")
    print(f"Total time taken: {total_time:.2f} seconds")
    print(f"FPS: {fps:.2f}")

# Entry point of the script
if __name__ == "__main__":
    args = parse_args()
    main(args)
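
# Example invocation (a sketch; the SDXL base model ID and the local paths
# below are assumptions, not pinned by this script):
#
#   accelerate launch eval_controlnet_sdxl_light.py \
#       --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
#       --controlnet_model_name_or_path ./checkpoints/controlnet \
#       --output_dir ./eval_out \
#       --repo ByteDance/SDXL-Lightning \
#       --ckpt sdxl_lightning_4step_unet.safetensors \
#       --num_inference_steps 4 \
#       --mixed_precision fp16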