# controlnet_dev / model_compare_full / gen_compare_image.py
# Author: takuma104
# Commit 2e343db: "add model_compare_full"
import numpy as np
import torch
import sys
import os
from diffusers import (
StableDiffusionControlNetPipeline,
AutoencoderKL,
UNet2DConditionModel,
)
from diffusers.utils import load_image
# Fixed positive / negative text prompts used by generate_image() for every
# image, so runs differ only in the loaded models, the control image and the seed.
test_prompt, test_negative_prompt = (
    "best quality, extremely detailed",
    "lowres, bad anatomy, worst quality, low quality",
)
def generate_image(
    seed,
    control,
    *,
    pipeline=None,
    device="cuda",
    width=512,
    height=512,
    num_inference_steps=30,
):
    """Generate one image for ``seed`` conditioned on the ``control`` image.

    Calls the ControlNet pipeline with the module-level ``test_prompt`` and
    ``test_negative_prompt`` and returns the first generated image.

    Args:
        seed: Seed for the ``torch.Generator`` used during sampling; using
            the same seeds for every model keeps the outputs comparable.
        control: Conditioning image, passed to the pipeline as ``image``
            (in this script, the result of ``load_image``).
        pipeline: Pipeline to call. When ``None`` (the default), the
            module-level ``pipe`` created in the ``__main__`` block is used.
            Pass one explicitly to use this function after importing the
            module.
        device: Device for the ``torch.Generator``. It should match the
            device the pipeline runs on.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    if pipeline is None:
        # Backward-compatible fallback: the script assigns ``pipe`` at module
        # level before calling this function.
        pipeline = pipe

    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipeline(
        prompt=test_prompt,
        negative_prompt=test_negative_prompt,
        width=width,
        height=height,
        generator=generator,
        image=control,
        num_inference_steps=num_inference_steps,
    )
    return result.images[0]
if __name__ == "__main__":
    # Usage: gen_compare_image.py <controlnet_model> [<base_model_id>]
    #   controlnet_model: key of control_image_dict (e.g. "canny"); the
    #     ControlNet pipeline is loaded from ../../control_sd15_<controlnet_model>.
    #   base_model_id: optional model id/path whose "unet" subfolder replaces
    #     the pipeline's UNet.
    # Writes 16 images (seeds 0-15) to
    #   ./all_controlnet_models/<controlnet_model>/<output_type>/output_NN.png
    # where output_type is "sd15" or the last path segment of base_model_id.

    # Control image (hosted in the controlnet_dev repo) for each model type.
    control_image_dict = {
        "canny": "control_bird_canny.png",
        "depth": "control_vermeer_depth.png",
        "hed": "control_bird_hed.png",
        "mlsd": "control_room_mlsd.png",
        "normal": "control_human_normal.png",
        "openpose": "control_human_openpose.png",
        "scribble": "control_vermeer_scribble.png",
        "seg": "control_room_seg.png",
    }

    # Validate the command line before creating folders or loading models.
    if len(sys.argv) not in (2, 3):
        sys.exit(f"usage: {sys.argv[0]} <controlnet_model> [<base_model_id>]")
    controlnet_model = sys.argv[1]
    if controlnet_model not in control_image_dict:
        sys.exit(
            f"unknown controlnet model {controlnet_model!r}; "
            f"expected one of: {', '.join(sorted(control_image_dict))}"
        )
    base_model_id = sys.argv[2] if len(sys.argv) == 3 else None
    model_id = f"../../control_sd15_{controlnet_model}"

    control_image = load_image(
        "https://huggingface.co/takuma104/controlnet_dev/resolve/main/"
        f"gen_compare/control_images/converted/{control_image_dict[controlnet_model]}"
    )

    # Branch on base_model_id itself, not on a derived folder name, so a base
    # model whose name happens to be "sd15" still gets its UNet applied.
    # NOTE: ``pipe`` must stay a module-level global; generate_image() reads it.
    if base_model_id:
        print(f"{base_model_id} unet only override config")
        unet = UNet2DConditionModel.from_pretrained(
            base_model_id, subfolder="unet"
        ).to("cuda")
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_id, unet=unet
        ).to("cuda")
        # Last path segment works for both "org/name" hub ids and local paths.
        output_type = os.path.basename(base_model_id.rstrip("/"))
    else:
        print("SD15 no override config")
        pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id).to(
            "cuda"
        )
        output_type = "sd15"

    output_folder = os.path.join(
        "./all_controlnet_models", controlnet_model, output_type
    )
    os.makedirs(output_folder, exist_ok=True)

    # Fixed seeds 0-15 so outputs line up one-to-one across compared models.
    for seed in range(16):
        image = generate_image(seed=seed, control=control_image)
        image.save(os.path.join(output_folder, f"output_{seed:02d}.png"))