|
import sys |
|
|
|
import torch |
|
import os |
|
import json |
|
import argparse |
|
sys.path.append(os.getcwd()) |
|
|
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
import numpy as np |
|
import glob |
|
|
|
|
|
from diffusers import StableDiffusionPipeline, DDIMScheduler |
|
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler |
|
from diffusers.utils import load_image |
|
from diffusers import StableDiffusionPipeline, DDPMPipeline, DDIMPipeline, PNDMPipeline, PNDMLMPipeline, DDPMLMPipeline, DPMLMPipeline, UniPCPipeline, LDMPipeline, PNDMScheduler, UniPCMultistepScheduler,DDIMScheduler |
|
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler |
|
from scheduler.scheduling_ddim_lm import DDIMLMScheduler |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
|
|
|
|
# Maps each DPM-Solver sampler name to (algorithm_type, use_lm) for the
# project-local DPMSolverMultistepLMScheduler.
_DPM_VARIANTS = {
    'dpm_lm': ("dpmsolver", True),
    'dpm': ("dpmsolver", False),
    'dpm++': ("dpmsolver++", False),
    'dpm++_lm': ("dpmsolver++", True),
}

# (prompt, negative_prompt) pairs rendered on every run.
_PROMPTS = [
    ['the mona lisa', ''],
    ['an asian girl', ''],
    ['an asian princess', ''],
    ['a portrait of a beautiful woman standing amidst a bed of vibrant tulips.', ''],
    ['a stunning Arabic woman dressed in traditional clothing', ''],
    ['a stunning Asian woman dressed in traditional clothing', ''],
    ['a stunning Black woman dressed in traditional clothing', ''],
    ['a stunning German woman dressed in traditional clothing', ''],
    ['a stunning Japan woman dressed in traditional clothing', ''],
    ['a stunning Chinese woman dressed in traditional clothing', ''],
    ['a stunning Jewish woman dressed in traditional clothing', ''],
]


def _select_scheduler(pipe, sampler_type, lamb, kappa, freeze):
    """Replace ``pipe.scheduler`` with the one named by ``sampler_type``.

    Every new scheduler is built from the config of the scheduler the
    pipeline was loaded with. If ``sampler_type`` matches no known name the
    pipeline keeps its loaded scheduler and a warning is printed to stderr.
    """
    base_config = pipe.scheduler.config

    if sampler_type in _DPM_VARIANTS:
        algorithm_type, use_lm = _DPM_VARIANTS[sampler_type]
        pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(base_config)
        # These are set after construction, exactly as the script always did.
        pipe.scheduler.config.solver_order = 3
        pipe.scheduler.config.algorithm_type = algorithm_type
        pipe.scheduler.lamb = lamb
        pipe.scheduler.lm = use_lm
    elif sampler_type == 'pndm':
        pipe.scheduler = PNDMScheduler.from_config(base_config)
    elif sampler_type == 'ddim':
        pipe.scheduler = DDIMScheduler.from_config(base_config)
    elif sampler_type == 'ddim_lm':
        pipe.scheduler = DDIMLMScheduler.from_config(base_config)
        pipe.scheduler.lamb = lamb
        pipe.scheduler.lm = True
        pipe.scheduler.kappa = kappa
        pipe.scheduler.freeze = freeze
    elif sampler_type == 'unipc':
        pipe.scheduler = UniPCMultistepScheduler.from_config(base_config)
    else:
        # Previously a silent fall-through (the default 'lag' lands here too).
        # The behaviour is kept; it is only made visible.
        print(
            f"warning: unrecognized sampler_type {sampler_type!r}; "
            "keeping the pipeline's loaded scheduler",
            file=sys.stderr,
        )


def _make_canny_condition(image_path, low_threshold=100, high_threshold=200):
    """Load ``image_path`` and return its Canny edge map as a 3-channel PIL image."""
    source = np.array(load_image(image_path))
    edges = cv2.Canny(source, low_threshold, high_threshold)
    # Canny returns a single-channel map; replicate it into three channels.
    edges = edges[:, :, None]
    edges = np.concatenate([edges, edges, edges], axis=2)
    return Image.fromarray(edges)


def main():
    """Sample ControlNet-canny images for a fixed prompt list and save them.

    Steps:
      1. Parse command-line options.
      2. Load the ControlNet and Stable Diffusion ControlNet pipeline in
         float16, enable model CPU offload and disable the safety checker.
      3. Install the scheduler named by ``--sampler_type``.
      4. Build a Canny edge map from ``--original_image_path``.
      5. Render one image per built-in prompt and write it as a PNG into
         ``--save_dir``.

    NOTE(review): ``--prompt`` is parsed but never read; the built-in prompt
    list is always used. ``--seed`` only seeds the global RNG before model
    loading, and each image is then re-seeded from the inner loop
    (currently seed 0). Both are left unchanged to keep default outputs stable.
    """
    parser = argparse.ArgumentParser(description="sampling script for ControlNet-canny.")
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    parser.add_argument('--guidance', type=float, default=7.5)
    parser.add_argument('--sampler_type', type=str, default='lag')
    parser.add_argument('--prompt', type=str, default='an asian girl')
    parser.add_argument('--original_image_path', type=str, default="/xxx/xxx/data/input_image_vermeer.png")
    parser.add_argument('--lamb', type=float, default=5.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--freeze', type=float, default=0.0)

    parser.add_argument('--save_dir', type=str, default='/xxx/xxx/result/0402')
    parser.add_argument('--controlnet_dir', type=str, default="/xxx/xxx/sd-controlnet-canny")
    parser.add_argument('--sd_dir', type=str, default="/xxx/xxx/stable-diffusion-v1-5")
    # The output filename uses a model tag, but no such option was ever
    # registered, so saving crashed with AttributeError. When omitted, the
    # tag defaults to the last path component of --sd_dir.
    parser.add_argument('--model', type=str, default=None)

    # First pass tolerates unknown flags so the sampler-specific options can
    # be registered before the strict parse below.
    preliminary, _ = parser.parse_known_args()
    if preliminary.sampler_type in ['bdia']:
        parser.add_argument('--bdia_gamma', type=float, default=0.5)
    if preliminary.sampler_type in ['edict']:
        parser.add_argument('--edict_p', type=float, default=0.93)
    args = parser.parse_args()

    sampler_type = args.sampler_type
    guidance_scale = args.guidance
    num_inference_steps = args.num_inference_steps
    model_tag = args.model or os.path.basename(os.path.normpath(args.sd_dir))

    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    controlnet = ControlNetModel.from_pretrained(
        args.controlnet_dir, torch_dtype=torch.float16, use_safetensors=True
    )
    control_pipe = StableDiffusionControlNetPipeline.from_pretrained(
        args.sd_dir, controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True
    )
    control_pipe.enable_model_cpu_offload()
    control_pipe.safety_checker = None

    _select_scheduler(control_pipe, sampler_type, args.lamb, args.kappa, args.freeze)

    canny_image = _make_canny_condition(args.original_image_path)

    for prompt, negative_prompt in _PROMPTS:
        for seed in range(1):
            torch.manual_seed(seed)
            res = control_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=canny_image,
                num_inference_steps=num_inference_steps,
                # Previously omitted: --guidance appeared in the filename
                # but was never applied to sampling.
                guidance_scale=guidance_scale,
            ).images[0]

            filename = (
                f"{model_tag}_{prompt[:20]}_seed{seed}_{sampler_type}"
                f"_infer{num_inference_steps}_g{guidance_scale}_lamb{args.lamb}.png"
            )
            res.save(os.path.join(save_dir, filename))
|
|
|
|
|
|
|
# Run the sampling entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()