controlnet_dev / model_compare /gen_compare_image.py
takuma104's picture
add scripts and markdown table
e75ac26
raw
history blame
2.34 kB
import numpy as np
import torch
import sys
import os
from diffusers import (
StableDiffusionControlNetPipeline,
AutoencoderKL,
UNet2DConditionModel,
)
from diffusers.utils import load_image
# Prompts shared by every generated image so that outputs differ only by
# seed and by the pipeline configuration under comparison.
test_prompt = "best quality, extremely detailed"
test_negative_prompt = "lowres, bad anatomy, worst quality, low quality"


def generate_image(seed, control, pipeline=None, device="cuda"):
    """Generate one 512x512 image from the fixed test prompts.

    Args:
        seed: Integer used to seed the ``torch.Generator`` handed to the
            pipeline.
        control: Conditioning image forwarded to the pipeline as ``image``.
        pipeline: Callable pipeline to run. When ``None`` (the default), the
            module-level ``pipe`` global is used, which is how the script
            entry point in this file drives the function.
        device: Device on which the random generator is created. Defaults to
            ``"cuda"``, matching the device the script moves its pipelines to.

    Returns:
        The first element of the ``images`` attribute of the pipeline result.

    Raises:
        NameError: If ``pipeline`` is ``None`` and no module-level ``pipe``
            has been defined.
    """
    if pipeline is None:
        # Backward-compatible fallback: the script assigns ``pipe`` at module
        # level before calling this function without an explicit pipeline.
        pipeline = pipe

    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipeline(
        prompt=test_prompt,
        negative_prompt=test_negative_prompt,
        width=512,
        height=512,
        generator=generator,
        image=control,
    )
    return result.images[0]
if __name__ == "__main__":
output_image_root_folder = "./canny"
model_id = f"../../control_sd15_canny"
base_model_id = sys.argv[1] if len(sys.argv) == 2 else None
canny_edged_image = load_image(
"https://huggingface.co/takuma104/controlnet_dev/resolve/main/vermeer_canny_edged.png"
)
if base_model_id:
unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet").to(
"cuda"
)
vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae").to("cuda")
output_types = [
base_model_id.split("/")[1] + suffix for suffix in ["_unet", "_unet_vae"]
]
else:
output_types = ["sd15"]
for output_type in output_types:
if output_type == "sd15":
print("SD15 no override config")
pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id).to(
"cuda"
)
elif output_type.endswith("_unet"):
print(f"{base_model_id} unet only override config")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
model_id, unet=unet
).to("cuda")
elif output_type.endswith("_unet_vae"):
print(f"{base_model_id} unet & vae override config")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
model_id, unet=unet, vae=vae
).to("cuda")
output_folder = f"{output_image_root_folder}/{output_type}"
os.makedirs(output_folder, exist_ok=True)
for seed in range(32):
image = generate_image(seed=seed, control=canny_edged_image)
image.save(f"{output_folder}/output_{seed:02d}.png")