williamberman committed
Commit: d89243e
Parent(s): 2bbd7a0

update dep

Files changed:
- diffusion.py +17 -1
- sdxl.py +157 -165
- sdxl_models.py +45 -38
diffusion.py  CHANGED

@@ -16,11 +16,27 @@ def make_sigmas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=default_
     return sigmas


+_with_tqdm = False
+
+
+def set_with_tqdm(it):
+    global _with_tqdm
+
+    _with_tqdm = it
+
+
 @torch.no_grad()
 def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_weights):
     x_t = x_T

-    for i in range(len(timesteps) - 1, -1, -1):
+    iter_over = range(len(timesteps) - 1, -1, -1)
+
+    if _with_tqdm:
+        from tqdm import tqdm
+
+        iter_over = tqdm(iter_over)
+
+    for i in iter_over:
         t = timesteps[i].unsqueeze(0)
         sigma = sigmas[t]
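The new _with_tqdm flag gates an optional progress bar around the solver loop. A minimal usage sketch, assuming only what the hunk above shows; the commented-out call stands in for a real invocation:

    # Hypothetical usage: flip the module-level flag before running a diffusion loop.
    import diffusion

    diffusion.set_with_tqdm(True)   # subsequent solver loops wrap their iteration range in tqdm
    # ... build eps_theta, timesteps, sigmas, x_T, rk_steps_weights as usual, then:
    # x_0 = diffusion.rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_weights)
    diffusion.set_with_tqdm(False)  # turn the progress bar back off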
sdxl.py  CHANGED

@@ -9,7 +9,6 @@ import torch
 import torch.nn.functional as F
 import torchvision.transforms
 import torchvision.transforms.functional as TF
-import webdataset as wds
 from PIL import Image
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import default_collate

@@ -348,6 +347,8 @@ class SDXLTraining:


 def get_sdxl_dataset(train_shards: str, shuffle_buffer_size: int, batch_size: int, proportion_empty_prompts: float, get_sdxl_conditioning_images=None):
+    import webdataset as wds
+
     dataset = (
         wds.WebDataset(
             train_shards,
@@ -443,61 +444,85 @@ def get_random_crop_params(input_size: Tuple[int, int], output_size: Tuple[int,
     return i, j, th, tw


+def get_adapter_openpose_conditioning_image(image, open_pose):
     resolution = image.width

+    conditioning_image = open_pose(image, detect_resolution=resolution, image_resolution=resolution, return_pil=False)
+
+    if (conditioning_image == 0).all():
+        return None, None
+
+    conditioning_image_as_pil = Image.fromarray(conditioning_image)
+
+    conditioning_image = TF.to_tensor(conditioning_image)
+
+    return dict(conditioning_image=conditioning_image, conditioning_image_as_pil=conditioning_image_as_pil)
+
+
+def get_controlnet_canny_conditioning_image(image):
+    import cv2
+
+    conditioning_image = np.array(image)
+    conditioning_image = cv2.Canny(conditioning_image, 100, 200)
+    conditioning_image = conditioning_image[:, :, None]
+    conditioning_image = np.concatenate([conditioning_image, conditioning_image, conditioning_image], axis=2)
+
+    conditioning_image_as_pil = Image.fromarray(conditioning_image)
+
+    conditioning_image = TF.to_tensor(conditioning_image)
+
+    return dict(conditioning_image=conditioning_image, conditioning_image_as_pil=conditioning_image_as_pil)
+
+
+def get_controlnet_pre_encoded_controlnet_inpainting_conditioning_image(image, conditioning_image_mask):
+    resolution = image.width
+
+    if conditioning_image_mask is None:
+        if random.random() <= 0.25:
+            conditioning_image_mask = np.ones((resolution, resolution), np.float32)
+        else:
+            conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)
+
+    conditioning_image_mask = torch.from_numpy(conditioning_image_mask)
+
+    conditioning_image_mask = conditioning_image_mask[None, :, :]
+
+    conditioning_image = TF.to_tensor(image)
+
+    # Where the mask is 1, zero out the pixels. Note that this requires the image to be
+    # concatenated with the mask so that the network knows the zeroed-out pixels come
+    # from the mask and are not just zero in the original image.
+    conditioning_image = conditioning_image * (conditioning_image_mask < 0.5)
+
+    conditioning_image_as_pil = TF.to_pil_image(conditioning_image)
+
+    conditioning_image = TF.normalize(conditioning_image, [0.5], [0.5])
+
+    return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)
+
+
+def get_controlnet_inpainting_conditioning_image(image, conditioning_image_mask):
+    resolution = image.width
+
+    if conditioning_image_mask is None:
+        if random.random() <= 0.25:
+            conditioning_image_mask = np.ones((resolution, resolution), np.float32)
         else:
+            conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)
+
+    conditioning_image_mask = torch.from_numpy(conditioning_image_mask)
+
+    conditioning_image_mask = conditioning_image_mask[None, :, :]
+
+    conditioning_image = TF.to_tensor(image)
+
+    # Just zero out the pixels which will be masked.
+    conditioning_image_as_pil = TF.to_pil_image(conditioning_image * (conditioning_image_mask < 0.5))
+
+    # Where the mask is 1, set the pixel to the "special" masked value -1.
+    # -1 is outside of the 0-1 range that the controlnet's normalized input is in.
+    conditioning_image = conditioning_image * (conditioning_image_mask < 0.5) + -1.0 * (conditioning_image_mask >= 0.5)

     return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)
@@ -830,102 +855,112 @@ def sdxl_eps_theta(

 known_negative_prompt = "text, watermark, low-quality, signature, moiré pattern, downsampling, aliasing, distorted, blurry, glossy, blur, jpeg artifacts, compression artifacts, poorly drawn, low-resolution, bad, distortion, twisted, excessive, exaggerated pose, exaggerated limbs, grainy, symmetrical, duplicate, error, pattern, beginner, pixelated, fake, hyper, glitch, overexposed, high-contrast, bad-contrast"

+if __name__ == "__main__":
+    from argparse import ArgumentParser
+
+    args = ArgumentParser()
+    args.add_argument("--prompts", required=True, type=str, nargs="+")
+    args.add_argument("--negative_prompts", required=False, type=str, nargs="+")
+    args.add_argument("--use_known_negative_prompt", action="store_true")
+    args.add_argument("--num_images_per_prompt", required=True, type=int, default=1)
+    args.add_argument("--num_inference_steps", required=False, type=int, default=50)
+    args.add_argument("--images", required=False, type=str, default=None, nargs="+")
+    args.add_argument("--masks", required=False, type=str, default=None, nargs="+")
+    args.add_argument("--controlnet_checkpoint", required=False, type=str, default=None)
+    args.add_argument("--controlnet", required=False, choices=["SDXLControlNet", "SDXLControlNetFull", "SDXLControNetPreEncodedControlnetCond"], default=None)
+    args.add_argument("--adapter_checkpoint", required=False, type=str, default=None)
+    args.add_argument("--device", required=False, default=None)
+    args.add_argument("--dtype", required=False, default="fp16", choices=["fp16", "fp32"])
+    args.add_argument("--guidance_scale", required=False, default=5.0, type=float)
+    args.add_argument("--seed", required=False, type=int)
+    args = args.parse_args()

-    if device is None:
+    if args.device is None:
         if torch.cuda.is_available():
             device = "cuda"
         elif torch.backends.mps.is_available():
             device = "mps"

+    if args.dtype == "fp16":
+        dtype = torch.float16
+
         text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16)
         text_encoder_one.to(device=device)

         text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16)
         text_encoder_two.to(device=device)

         vae = SDXLVae.load_fp16_fix(device=device)
+        vae.to(torch.float16)

         unet = SDXLUNet.load_fp16(device=device)
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+
+        text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+        text_encoder_one.to(device=device)
+
+        text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2")
+        text_encoder_two.to(device=device)
+
+        vae = SDXLVae.load_fp16_fix(device=device)
+
+        unet = SDXLUNet.load_fp32(device=device)
     else:
+        assert False

-        controlnet = SDXLControlNet.load(controlnet_checkpoint, device=device, dtype=torch.float16)
-    elif controlnet == "SDXLControlNetFull":
-        controlnet = SDXLControlNetFull.load(controlnet_checkpoint, device=device, dtype=torch.float16)
-    elif controlnet == "SDXLControlNetPreEncodedControlnetCond":
-        controlnet = SDXLControlNetPreEncodedControlnetCond.load(controlnet_checkpoint, device=device, dtype=torch.float16)
-    else:
-        assert False
+    if args.controlnet == "SDXLControlNet":
+        controlnet = SDXLControlNet.load(args.controlnet_checkpoint, device=device)
+        controlnet.to(dtype)
+    elif args.controlnet == "SDXLControlNetFull":
+        controlnet = SDXLControlNetFull.load(args.controlnet_checkpoint, device=device)
+        controlnet.to(dtype)
+    elif args.controlnet == "SDXLControlNetPreEncodedControlnetCond":
+        controlnet = SDXLControlNetPreEncodedControlnetCond.load(args.controlnet_checkpoint, device=device)
+        controlnet.to(dtype)
+    else:
+        controlnet = None

+    if args.adapter_checkpoint is not None:
+        adapter = SDXLAdapter.load(args.adapter_checkpoint, device=device)
+        adapter.to(dtype)
+    else:
+        adapter = None

+    sigmas = make_sigmas(device=device).to(unet.dtype)

+    timesteps = torch.linspace(0, sigmas.numel() - 1, args.num_inference_steps, dtype=torch.long, device=unet.device)

-    prompts_ = []
-    for prompt in prompts:
-        prompts_ += [prompt] * num_images
-
-    if negative_prompts is not None:
-        if isinstance(negative_prompts, str):
-            negative_prompts = [negative_prompts]
-        negative_prompts_ = []
-        for negative_prompt in negative_prompts:
-            negative_prompts_ += [negative_prompt] * num_images
+    prompts = []
+    for prompt in args.prompts:
+        prompts += [prompt] * args.num_images_per_prompt

+    if args.use_known_negative_prompt:
+        args.negative_prompts = [known_negative_prompt]

+    if args.negative_prompts is None:
+        negative_prompts = None
+    elif len(args.negative_prompts) == 1:
+        negative_prompts = args.negative_prompts * len(prompts)
+    elif len(args.negative_prompts) == len(args.prompts):
+        negative_prompts = []
+        for negative_prompt in args.negative_prompts:
+            negative_prompts += [negative_prompt] * args.num_images_per_prompt
+    else:
+        assert False

+    if args.images is not None:
+        images = []
+
+        for image_idx, image in enumerate(args.images):
+            image = Image.open(image)
+            image = image.convert("RGB")
+            image = image.resize((1024, 1024))
             image = TF.to_tensor(image)

+            if args.masks is not None:
+                mask = args.masks[image_idx]
+                mask = Image.open(mask)
                 mask = mask.convert("L")
                 mask = mask.resize((1024, 1024))
                 mask = TF.to_tensor(mask)

@@ -933,83 +968,40 @@ def gen_sdxl_simplified_interface(
                 if isinstance(controlnet, SDXLControlNetPreEncodedControlnetCond):
                     image = image * (mask < 0.5)
                     image = TF.normalize(image, [0.5], [0.5])
+                    image = vae.encode(image[None, :, :, :].to(dtype=vae.dtype, device=vae.device)).to(dtype=controlnet.dtype, device=controlnet.device)
                     mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None, :, :, :].to(dtype=image.dtype, device=image.device)
                     image = torch.concat((image, mask), dim=1)
                 else:
+                    image = (image * (mask < 0.5) + -1.0 * (mask >= 0.5)).to(dtype=dtype, device=device)
                     image = image[None, :, :, :]

+            images += [image] * args.num_images_per_prompt
+
+        images = torch.concat(images)
     else:
+        images = None

+    if args.seed is None:
+        generator = None
     else:
+        generator = torch.Generator(device).manual_seed(args.seed)

+    images = sdxl_diffusion_loop(
+        prompts=prompts,
         unet=unet,
         text_encoder_one=text_encoder_one,
         text_encoder_two=text_encoder_two,
+        images=images,
         controlnet=controlnet,
         adapter=adapter,
+        sigmas=sigmas,
+        timesteps=timesteps,
+        guidance_scale=args.guidance_scale,
+        negative_prompts=negative_prompts,
+        generator=generator,
     )

-    x_0 = vae.output_tensor_to_pil(x_0)
-
-    return x_0
-
-
-if __name__ == "__main__":
-    from argparse import ArgumentParser
-
-    args = ArgumentParser()
-    args.add_argument("--prompt", required=True, type=str)
-    args.add_argument("--num_images", required=True, type=int, default=1)
-    args.add_argument("--num_inference_steps", required=False, type=int, default=50)
-    args.add_argument("--image", required=False, type=str, default=None)
-    args.add_argument("--mask", required=False, type=str, default=None)
-    args.add_argument("--controlnet_checkpoint", required=False, type=str, default=None)
-    args.add_argument("--controlnet", required=False, choices=["SDXLControlNet", "SDXLControlNetFull", "SDXLControNetPreEncodedControlnetCond"], default=None)
-    args.add_argument("--adapter_checkpoint", required=False, type=str, default=None)
-    args.add_argument("--apply_conditioning", choices=["canny"], required=False, default=None)
-    args.add_argument("--device", required=False, default=None)
-    args = args.parse_args()
-
-    images = gen_sdxl_simplified_interface(
-        prompt=args.prompt,
-        num_images=args.num_images,
-        num_inference_steps=args.num_inference_steps,
-        images=[args.image],
-        masks=[args.mask],
-        controlnet_checkpoint=args.controlnet_checkpoint,
-        controlnet=args.controlnet,
-        adapter_checkpoint=args.adapter_checkpoint,
-        apply_conditioning=args.apply_conditioning,
-        device=args.device,
-        negative_prompt=known_negative_prompt,
-    )
+    images = vae.output_tensor_to_pil(vae.decode(images))

     for i, image in enumerate(images):
         image.save(f"out_{i}.png")
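The old gen_sdxl_simplified_interface entry point is gone; generation now runs through the argparse block above. A hedged example of driving it, using only flags that appear in the diff; the prompt text is a placeholder and the script is assumed to be run from the repo root:

    # Hypothetical invocation of the new CLI (text-to-image only, no controlnet/adapter).
    import subprocess
    import sys

    subprocess.run(
        [
            sys.executable,
            "sdxl.py",
            "--prompts", "a photo of a corgi wearing sunglasses",
            "--num_images_per_prompt", "2",
            "--use_known_negative_prompt",
            "--dtype", "fp16",
            "--guidance_scale", "5.0",
            "--seed", "0",
        ],
        check=True,
    )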
sdxl_models.py  CHANGED

@@ -1,12 +1,11 @@
 import math
 import os
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union

 import safetensors.torch
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
-import xformers.ops
 from PIL import Image
 from torch import nn

@@ -21,12 +20,14 @@ class ModelUtils:
         return next(self.parameters()).device

     @classmethod
-    def load(cls, load_from: str, device, overrides: Optional[List[str]] = None):
+    def load(cls, load_from: str, device, overrides: Optional[Union[str, List[str]]] = None):
         import load_state_dict_patch

         load_from = [load_from]

         if overrides is not None:
+            if isinstance(overrides, str):
+                overrides = [overrides]
             load_from += overrides

         state_dict = {}
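With the widened signature, a single override checkpoint can be passed as a bare string instead of a one-element list. A sketch under the assumption that SDXLControlNet is exported from sdxl_models and inherits this ModelUtils.load classmethod (sdxl.py calls it the same way); the checkpoint paths are placeholders:

    # Hypothetical call; both paths are placeholders.
    from sdxl_models import SDXLControlNet

    # Previously, overrides had to be a list:
    controlnet = SDXLControlNet.load("controlnet.safetensors", device="cuda", overrides=["override.safetensors"])

    # Now a single override string is wrapped into a list automatically:
    controlnet = SDXLControlNet.load("controlnet.safetensors", device="cuda", overrides="override.safetensors")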
@@ -1323,51 +1324,57 @@ class TransformerDecoderBlock(nn.Module):
         return hidden_states


-class AttentionMixin:
-    attention_implementation: Literal["xformers", "torch_2.0_scaled_dot_product"] = "xformers"
-
-    @classmethod
-    def attention(cls, to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
-        batch_size, q_seq_len, channels = hidden_states.shape
+_attention_implementation: Literal["xformers", "torch_2.0_scaled_dot_product"] = "torch_2.0_scaled_dot_product"
+
+
+def set_attention_implementation(impl: Literal["xformers", "torch_2.0_scaled_dot_product"]):
+    global _attention_implementation
+    _attention_implementation = impl
+
+
+def attention(to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
+    batch_size, q_seq_len, channels = hidden_states.shape
+
+    if encoder_hidden_states is not None:
+        kv = encoder_hidden_states
+    else:
+        kv = hidden_states
+
+    kv_seq_len = kv.shape[1]
+
+    query = to_q(hidden_states)
+    key = to_k(kv)
+    value = to_v(kv)
+
+    if _attention_implementation == "xformers":
+        import xformers.ops
+
+        query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
+        key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+        value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
+
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
+    elif _attention_implementation == "torch_2.0_scaled_dot_product":
+        query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+        key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+        value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value)
+
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, q_seq_len, channels).contiguous()
+    else:
+        assert False
+
+    hidden_states = to_out(hidden_states)
+
+    return hidden_states


-class Attention(nn.Module, AttentionMixin):
+class Attention(nn.Module):
     def __init__(self, channels, encoder_hidden_states_dim):
         super().__init__()
         self.to_q = nn.Linear(channels, channels, bias=False)

@@ -1376,10 +1383,10 @@ class Attention(nn.Module, AttentionMixin):
         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))

     def forward(self, hidden_states, encoder_hidden_states=None):
-        return
+        return attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)


-class VaeMidBlockAttention(nn.Module, AttentionMixin):
+class VaeMidBlockAttention(nn.Module):
     def __init__(self, channels):
         super().__init__()
         self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)

@@ -1397,7 +1404,7 @@ class VaeMidBlockAttention(nn.Module, AttentionMixin):

         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        hidden_states =
+        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)

         hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
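The attention kernel selection is now a module-level switch instead of an AttentionMixin class attribute, and it defaults to torch's scaled_dot_product_attention, with xformers as an opt-in. A small smoke-test sketch; the channel and sequence sizes are arbitrary, and it assumes to_k/to_v accept channel-sized inputs when encoder_hidden_states_dim equals channels (their definitions are not part of this hunk):

    # Hypothetical smoke test of the refactored attention path.
    import torch

    import sdxl_models

    sdxl_models.set_attention_implementation("torch_2.0_scaled_dot_product")  # the new default
    # sdxl_models.set_attention_implementation("xformers")  # opt-in, requires xformers installed

    attn = sdxl_models.Attention(channels=640, encoder_hidden_states_dim=640)

    hidden_states = torch.randn(1, 77, 640)  # (batch, sequence, channels)
    out = attn(hidden_states)                # self-attention; forward hard-codes head_dim=64
    print(out.shape)                         # expected: torch.Size([1, 77, 640])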