williamberman committed
Commit d89243e
1 Parent(s): 2bbd7a0

update dep

Files changed (3):
  1. diffusion.py +17 -1
  2. sdxl.py +157 -165
  3. sdxl_models.py +45 -38
diffusion.py CHANGED
@@ -16,11 +16,27 @@ def make_sigmas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=default_
     return sigmas
 
 
+_with_tqdm = False
+
+
+def set_with_tqdm(it):
+    global _with_tqdm
+
+    _with_tqdm = it
+
+
 @torch.no_grad()
 def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_weights):
     x_t = x_T
 
-    for i in range(len(timesteps) - 1, -1, -1):
+    iter_over = range(len(timesteps) - 1, -1, -1)
+
+    if _with_tqdm:
+        from tqdm import tqdm
+
+        iter_over = tqdm(iter_over)
+
+    for i in iter_over:
         t = timesteps[i].unsqueeze(0)
         sigma = sigmas[t]
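The progress bar is opt-in module state rather than a parameter threaded through every sampler call. A minimal usage sketch (it assumes only that `diffusion.py` is importable):

```python
import diffusion

# Opt in to a tqdm progress bar over the reverse-diffusion loop. The default stays
# off, so tqdm is only imported when a caller explicitly asks for it.
diffusion.set_with_tqdm(True)
```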
 
sdxl.py CHANGED
@@ -9,7 +9,6 @@ import torch
 import torch.nn.functional as F
 import torchvision.transforms
 import torchvision.transforms.functional as TF
-import webdataset as wds
 from PIL import Image
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import default_collate
@@ -348,6 +347,8 @@ class SDXLTraining:
 
 
 def get_sdxl_dataset(train_shards: str, shuffle_buffer_size: int, batch_size: int, proportion_empty_prompts: float, get_sdxl_conditioning_images=None):
+    import webdataset as wds
+
     dataset = (
         wds.WebDataset(
             train_shards,
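Deferring the `webdataset` import into `get_sdxl_dataset` (together with dropping the module-level import above) makes it a training-only dependency: importing `sdxl` for inference no longer requires `webdataset` to be installed, which is presumably the "update dep" in the commit message. The pattern in isolation (a hypothetical minimal module, not code from this repo):

```python
def get_dataset(train_shards: str):
    # Deferred import: inference-only users of this module never need the
    # webdataset dependency; it is resolved only when training actually starts.
    import webdataset as wds

    return wds.WebDataset(train_shards)
```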
@@ -443,61 +444,85 @@ def get_random_crop_params(input_size: Tuple[int, int], output_size: Tuple[int,
     return i, j, th, tw
 
 
-def get_sdxl_conditioning_images(image, adapter_type=None, controlnet_type=None, controlnet_variant=None, open_pose=None, conditioning_image_mask=None):
-    resolution = image.width
-
-    if adapter_type == "openpose":
-        conditioning_image = open_pose(image, detect_resolution=resolution, image_resolution=resolution, return_pil=False)
-
-        if (conditioning_image == 0).all():
-            return None, None
-
-        conditioning_image_as_pil = Image.fromarray(conditioning_image)
-
-        conditioning_image = TF.to_tensor(conditioning_image)
-
-    if controlnet_type == "canny":
-        import cv2
-
-        conditioning_image = np.array(image)
-        conditioning_image = cv2.Canny(conditioning_image, 100, 200)
-        conditioning_image = conditioning_image[:, :, None]
-        conditioning_image = np.concatenate([conditioning_image, conditioning_image, conditioning_image], axis=2)
-
-        conditioning_image_as_pil = Image.fromarray(conditioning_image)
-
-        conditioning_image = TF.to_tensor(conditioning_image)
-
-    if controlnet_type == "inpainting":
-        if conditioning_image_mask is None:
-            if random.random() <= 0.25:
-                conditioning_image_mask = np.ones((resolution, resolution), np.float32)
-            else:
-                conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)
-
-        conditioning_image_mask = torch.from_numpy(conditioning_image_mask)
-
-        conditioning_image_mask = conditioning_image_mask[None, :, :]
-
-        conditioning_image = TF.to_tensor(image)
-
-        if controlnet_variant == "pre_encoded_controlnet_cond":
-            # where mask is 1, zero out the pixels. Note that this requires mask to be concattenated
-            # with the mask so that the network knows the zeroed out pixels are from the mask and
-            # are not just zero in the original image
-            conditioning_image = conditioning_image * (conditioning_image_mask < 0.5)
-
-            conditioning_image_as_pil = TF.to_pil_image(conditioning_image)
-
-            conditioning_image = TF.normalize(conditioning_image, [0.5], [0.5])
-        else:
-            # Just zero out the pixels which will be masked
-            conditioning_image_as_pil = TF.to_pil_image(conditioning_image * (conditioning_image_mask < 0.5))
-
-            # where mask is set to 1, set to -1 "special" masked image pixel.
-            # -1 is outside of the 0-1 range that the controlnet normalized
-            # input is in.
-            conditioning_image = conditioning_image * (conditioning_image_mask < 0.5) + -1.0 * (conditioning_image_mask >= 0.5)
-
-    return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)
+def get_adapter_openpose_conditioning_image(image, open_pose):
+    resolution = image.width
+
+    conditioning_image = open_pose(image, detect_resolution=resolution, image_resolution=resolution, return_pil=False)
+
+    if (conditioning_image == 0).all():
+        return None, None
+
+    conditioning_image_as_pil = Image.fromarray(conditioning_image)
+
+    conditioning_image = TF.to_tensor(conditioning_image)
+
+    return dict(conditioning_image=conditioning_image, conditioning_image_as_pil=conditioning_image_as_pil)
+
+
+def get_controlnet_canny_conditioning_image(image):
+    import cv2
+
+    conditioning_image = np.array(image)
+    conditioning_image = cv2.Canny(conditioning_image, 100, 200)
+    conditioning_image = conditioning_image[:, :, None]
+    conditioning_image = np.concatenate([conditioning_image, conditioning_image, conditioning_image], axis=2)
+
+    conditioning_image_as_pil = Image.fromarray(conditioning_image)
+
+    conditioning_image = TF.to_tensor(conditioning_image)
+
+    return dict(conditioning_image=conditioning_image, conditioning_image_as_pil=conditioning_image_as_pil)
+
+
+def get_controlnet_pre_encoded_controlnet_inpainting_conditioning_image(image, conditioning_image_mask):
+    resolution = image.width
+
+    if conditioning_image_mask is None:
+        if random.random() <= 0.25:
+            conditioning_image_mask = np.ones((resolution, resolution), np.float32)
+        else:
+            conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)
+
+    conditioning_image_mask = torch.from_numpy(conditioning_image_mask)
+
+    conditioning_image_mask = conditioning_image_mask[None, :, :]
+
+    conditioning_image = TF.to_tensor(image)
+
+    # Where the mask is 1, zero out the pixels. Note that this requires the image to be
+    # concatenated with the mask so that the network knows the zeroed-out pixels come from
+    # the mask and are not just zero in the original image.
+    conditioning_image = conditioning_image * (conditioning_image_mask < 0.5)
+
+    conditioning_image_as_pil = TF.to_pil_image(conditioning_image)
+
+    conditioning_image = TF.normalize(conditioning_image, [0.5], [0.5])
+
+    return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)
+
+
+def get_controlnet_inpainting_conditioning_image(image, conditioning_image_mask):
+    resolution = image.width
+
+    if conditioning_image_mask is None:
+        if random.random() <= 0.25:
+            conditioning_image_mask = np.ones((resolution, resolution), np.float32)
+        else:
+            conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)
+
+    conditioning_image_mask = torch.from_numpy(conditioning_image_mask)
+
+    conditioning_image_mask = conditioning_image_mask[None, :, :]
+
+    conditioning_image = TF.to_tensor(image)
+
+    # Just zero out the pixels which will be masked.
+    conditioning_image_as_pil = TF.to_pil_image(conditioning_image * (conditioning_image_mask < 0.5))
+
+    # Where the mask is set to 1, set the pixel to the "special" masked value -1.
+    # -1 is outside of the 0-1 range that the normalized controlnet input is in.
+    conditioning_image = conditioning_image * (conditioning_image_mask < 0.5) + -1.0 * (conditioning_image_mask >= 0.5)
+
+    return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)
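The monolithic `get_sdxl_conditioning_images` is split into one helper per conditioning type, so call sites pick a function instead of threading `adapter_type`/`controlnet_type`/`controlnet_variant` flags. A minimal sketch of the new call sites, assuming the helpers are importable from `sdxl` (the image path is a hypothetical placeholder):

```python
from PIL import Image

from sdxl import get_controlnet_canny_conditioning_image, get_controlnet_inpainting_conditioning_image

# Hypothetical 1024x1024 RGB input image.
image = Image.open("example.png").convert("RGB").resize((1024, 1024))

# Canny edge map for a standard controlnet; returns the tensor plus a PIL preview.
canny = get_controlnet_canny_conditioning_image(image)
canny["conditioning_image_as_pil"].save("canny_preview.png")

# Inpainting conditioning; passing None for the mask samples a random training mask.
inpaint = get_controlnet_inpainting_conditioning_image(image, conditioning_image_mask=None)
conditioning = inpaint["conditioning_image"]
```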
 
@@ -830,102 +855,112 @@ def sdxl_eps_theta(
 
 known_negative_prompt = "text, watermark, low-quality, signature, moiré pattern, downsampling, aliasing, distorted, blurry, glossy, blur, jpeg artifacts, compression artifacts, poorly drawn, low-resolution, bad, distortion, twisted, excessive, exaggerated pose, exaggerated limbs, grainy, symmetrical, duplicate, error, pattern, beginner, pixelated, fake, hyper, glitch, overexposed, high-contrast, bad-contrast"
 
-# TODO probably just combine with sdxl_diffusion_loop
-def gen_sdxl_simplified_interface(
-    prompts: Union[str, List[str]],
-    negative_prompts: Optional[Union[str, List[str]]] = None,
-    controlnet_checkpoint: Optional[str] = None,
-    controlnet: Optional[Literal["SDXLControlNet", "SDXLContolNetFull", "SDXLControlNetPreEncodedControlnetCond"]] = None,
-    adapter_checkpoint: Optional[str] = None,
-    num_inference_steps=50,
-    images=None,
-    masks=None,
-    apply_conditioning: Optional[Literal["canny"]] = None,
-    num_images: int = 1,
-    guidance_scale=5.0,
-    device: Optional[str] = None,
-    text_encoder_one=None,
-    text_encoder_two=None,
-    unet=None,
-    vae=None,
-):
-    if device is None:
-        if torch.cuda.is_available():
-            device = "cuda"
-        elif torch.backends.mps.is_available():
-            device = "mps"
-
-    if text_encoder_one is None:
-        text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16)
-        text_encoder_one.to(device=device)
-
-    if text_encoder_two is None:
-        text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16)
-        text_encoder_two.to(device=device)
-
-    if vae is None:
-        vae = SDXLVae.load_fp16_fix(device=device)
-
-    if unet is None:
-        unet = SDXLUNet.load_fp16(device=device)
-
-    if isinstance(controlnet, str) and controlnet_checkpoint is not None:
-        if controlnet == "SDXLControlNet":
-            controlnet = SDXLControlNet.load(controlnet_checkpoint, device=device, dtype=torch.float16)
-        elif controlnet == "SDXLControlNetFull":
-            controlnet = SDXLControlNetFull.load(controlnet_checkpoint, device=device, dtype=torch.float16)
-        elif controlnet == "SDXLControlNetPreEncodedControlnetCond":
-            controlnet = SDXLControlNetPreEncodedControlnetCond.load(controlnet_checkpoint, device=device, dtype=torch.float16)
-        else:
-            assert False
-
-    if adapter_checkpoint is not None:
-        adapter = SDXLAdapter.load(adapter_checkpoint, device=device, dtype=torch.float16)
-    else:
-        adapter = None
-
-    sigmas = make_sigmas()
-
-    timesteps = torch.linspace(0, sigmas.numel() - 1, num_inference_steps, dtype=torch.long, device=unet.device)
-
-    if images is not None:
-        if not isinstance(images, list):
-            images = [images]
-
-        if masks is not None and not isinstance(masks, list):
-            masks = [masks]
-
-        images_ = []
-
-        for image_idx, image in enumerate(images):
-            if isinstance(image, str):
-                image = Image.open(image)
-                image = image.convert("RGB")
-                image = image.resize((1024, 1024))
-            elif isinstance(image, Image.Image):
-                ...
-            else:
-                assert False
-
-            if apply_conditioning == "canny":
-                import cv2
-
-                image = np.array(image)
-                image = cv2.Canny(image, 100, 200)
-                image = image[:, :, None]
-                controlnet_image = np.concatenate([controlnet_image, controlnet_image, controlnet_image], axis=2)
-
-            image = TF.to_tensor(image)
-
-            if masks is not None:
-                mask = masks[image_idx]
-                if isinstance(mask, str):
-                    mask = Image.open(mask)
-                elif isinstance(mask, Image.Image):
-                    ...
-                else:
-                    assert False
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+
+    args = ArgumentParser()
+    args.add_argument("--prompts", required=True, type=str, nargs="+")
+    args.add_argument("--negative_prompts", required=False, type=str, nargs="+")
+    args.add_argument("--use_known_negative_prompt", action="store_true")
+    args.add_argument("--num_images_per_prompt", required=True, type=int, default=1)
+    args.add_argument("--num_inference_steps", required=False, type=int, default=50)
+    args.add_argument("--images", required=False, type=str, default=None, nargs="+")
+    args.add_argument("--masks", required=False, type=str, default=None, nargs="+")
+    args.add_argument("--controlnet_checkpoint", required=False, type=str, default=None)
+    args.add_argument("--controlnet", required=False, choices=["SDXLControlNet", "SDXLControlNetFull", "SDXLControlNetPreEncodedControlnetCond"], default=None)
+    args.add_argument("--adapter_checkpoint", required=False, type=str, default=None)
+    args.add_argument("--device", required=False, default=None)
+    args.add_argument("--dtype", required=False, default="fp16", choices=["fp16", "fp32"])
+    args.add_argument("--guidance_scale", required=False, default=5.0, type=float)
+    args.add_argument("--seed", required=False, type=int)
+    args = args.parse_args()
+
+    if args.device is None:
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+
+    if args.dtype == "fp16":
+        dtype = torch.float16
+
+        text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16)
+        text_encoder_one.to(device=device)
+
+        text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16)
+        text_encoder_two.to(device=device)
+
+        vae = SDXLVae.load_fp16_fix(device=device)
+        vae.to(torch.float16)
+
+        unet = SDXLUNet.load_fp16(device=device)
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+
+        text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+        text_encoder_one.to(device=device)
+
+        text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2")
+        text_encoder_two.to(device=device)
+
+        vae = SDXLVae.load_fp16_fix(device=device)
+
+        unet = SDXLUNet.load_fp32(device=device)
+    else:
+        assert False
+
+    if args.controlnet == "SDXLControlNet":
+        controlnet = SDXLControlNet.load(args.controlnet_checkpoint, device=device)
+        controlnet.to(dtype)
+    elif args.controlnet == "SDXLControlNetFull":
+        controlnet = SDXLControlNetFull.load(args.controlnet_checkpoint, device=device)
+        controlnet.to(dtype)
+    elif args.controlnet == "SDXLControlNetPreEncodedControlnetCond":
+        controlnet = SDXLControlNetPreEncodedControlnetCond.load(args.controlnet_checkpoint, device=device)
+        controlnet.to(dtype)
+    else:
+        controlnet = None
+
+    if args.adapter_checkpoint is not None:
+        adapter = SDXLAdapter.load(args.adapter_checkpoint, device=device)
+        adapter.to(dtype)
+    else:
+        adapter = None
+
+    sigmas = make_sigmas(device=device).to(unet.dtype)
+
+    timesteps = torch.linspace(0, sigmas.numel() - 1, args.num_inference_steps, dtype=torch.long, device=unet.device)
+
+    prompts = []
+    for prompt in args.prompts:
+        prompts += [prompt] * args.num_images_per_prompt
+
+    if args.use_known_negative_prompt:
+        args.negative_prompts = [known_negative_prompt]
+
+    if args.negative_prompts is None:
+        negative_prompts = None
+    elif len(args.negative_prompts) == 1:
+        negative_prompts = args.negative_prompts * len(prompts)
+    elif len(args.negative_prompts) == len(args.prompts):
+        negative_prompts = []
+        for negative_prompt in args.negative_prompts:
+            negative_prompts += [negative_prompt] * args.num_images_per_prompt
+    else:
+        assert False
+
+    if args.images is not None:
+        images = []
+
+        for image_idx, image in enumerate(args.images):
+            image = Image.open(image)
+            image = image.convert("RGB")
+            image = image.resize((1024, 1024))
+            image = TF.to_tensor(image)
+
+            if args.masks is not None:
+                mask = args.masks[image_idx]
+                mask = Image.open(mask)
                 mask = mask.convert("L")
                 mask = mask.resize((1024, 1024))
                 mask = TF.to_tensor(mask)
@@ -933,83 +968,40 @@ def gen_sdxl_simplified_interface(
                 if isinstance(controlnet, SDXLControlNetPreEncodedControlnetCond):
                     image = image * (mask < 0.5)
                     image = TF.normalize(image, [0.5], [0.5])
-                    image = vae.encode(image[None, :, :, :].to(dtype=vae.dtype, device=vae.device)).to(dtype=unet.dtype, device=unet.device)
+                    image = vae.encode(image[None, :, :, :].to(dtype=vae.dtype, device=vae.device)).to(dtype=controlnet.dtype, device=controlnet.device)
                     mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None, :, :, :].to(dtype=image.dtype, device=image.device)
                     image = torch.concat((image, mask), dim=1)
                 else:
-                    image = (image * (mask < 0.5) + -1.0 * (mask >= 0.5)).to(dtype=unet.dtype, device=unet.device)
+                    image = (image * (mask < 0.5) + -1.0 * (mask >= 0.5)).to(dtype=dtype, device=device)
                     image = image[None, :, :, :]
 
-            images_.append(image)
+            images += [image] * args.num_images_per_prompt
 
-        images_ = torch.concat(images_)
+        images = torch.concat(images)
     else:
-        images_ = None
+        images = None
 
-    if isinstance(prompts, str):
-        prompts = [prompts]
-    prompts_ = []
-    for prompt in prompts:
-        prompts_ += [prompt] * num_images
-
-    if negative_prompts is not None:
-        if isinstance(negative_prompts, str):
-            negative_prompts = [negative_prompts]
-        negative_prompts_ = []
-        for negative_prompt in negative_prompts:
-            negative_prompts_ += [negative_prompt] * num_images
+    if args.seed is None:
+        generator = None
     else:
-        negative_prompts_ = None
+        generator = torch.Generator(device).manual_seed(args.seed)
 
-    x_0 = sdxl_diffusion_loop(
-        prompts=prompts_,
-        negative_prompts=negative_prompts_,
+    images = sdxl_diffusion_loop(
+        prompts=prompts,
         unet=unet,
         text_encoder_one=text_encoder_one,
         text_encoder_two=text_encoder_two,
-        sigmas=sigmas,
-        timesteps=timesteps,
+        images=images,
         controlnet=controlnet,
         adapter=adapter,
-        images=images_,
-        guidance_scale=guidance_scale,
+        sigmas=sigmas,
+        timesteps=timesteps,
+        guidance_scale=args.guidance_scale,
+        negative_prompts=negative_prompts,
+        generator=generator,
     )
 
-    x_0 = vae.decode(x_0.to(vae.dtype))
-    x_0 = vae.output_tensor_to_pil(x_0)
-
-    return x_0
-
-
-if __name__ == "__main__":
-    from argparse import ArgumentParser
-
-    args = ArgumentParser()
-    args.add_argument("--prompt", required=True, type=str)
-    args.add_argument("--num_images", required=True, type=int, default=1)
-    args.add_argument("--num_inference_steps", required=False, type=int, default=50)
-    args.add_argument("--image", required=False, type=str, default=None)
-    args.add_argument("--mask", required=False, type=str, default=None)
-    args.add_argument("--controlnet_checkpoint", required=False, type=str, default=None)
-    args.add_argument("--controlnet", required=False, choices=["SDXLControlNet", "SDXLControlNetFull", "SDXLControNetPreEncodedControlnetCond"], default=None)
-    args.add_argument("--adapter_checkpoint", required=False, type=str, default=None)
-    args.add_argument("--apply_conditioning", choices=["canny"], required=False, default=None)
-    args.add_argument("--device", required=False, default=None)
-    args = args.parse_args()
-
-    images = gen_sdxl_simplified_interface(
-        prompt=args.prompt,
-        num_images=args.num_images,
-        num_inference_steps=args.num_inference_steps,
-        images=[args.image],
-        masks=[args.mask],
-        controlnet_checkpoint=args.controlnet_checkpoint,
-        controlnet=args.controlnet,
-        adapter_checkpoint=args.adapter_checkpoint,
-        apply_conditioning=args.apply_conditioning,
-        device=args.device,
-        negative_prompt=known_negative_prompt,
-    )
+    images = vae.output_tensor_to_pil(vae.decode(images))
 
     for i, image in enumerate(images):
         image.save(f"out_{i}.png")
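With `gen_sdxl_simplified_interface` gone, generation goes through the new `__main__` block (e.g. `python sdxl.py --prompts "..." --num_images_per_prompt 1`) or directly through `sdxl_diffusion_loop`. A minimal fp16 text-to-image sketch of the programmatic path, mirroring the `__main__` block above (it assumes the model classes live in `sdxl_models` and the loop in `sdxl`, as in this repo; the prompt is a placeholder):

```python
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection

from diffusion import make_sigmas
from sdxl import sdxl_diffusion_loop
from sdxl_models import SDXLUNet, SDXLVae

device = "cuda"

# Load the SDXL text encoders, VAE, and UNet in fp16, as the CLI's fp16 branch does.
text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16).to(device)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16).to(device)

vae = SDXLVae.load_fp16_fix(device=device)
vae.to(torch.float16)

unet = SDXLUNet.load_fp16(device=device)

sigmas = make_sigmas(device=device).to(unet.dtype)
timesteps = torch.linspace(0, sigmas.numel() - 1, 50, dtype=torch.long, device=unet.device)

x_0 = sdxl_diffusion_loop(
    prompts=["a photo of a dog"],
    unet=unet,
    text_encoder_one=text_encoder_one,
    text_encoder_two=text_encoder_two,
    sigmas=sigmas,
    timesteps=timesteps,
    guidance_scale=5.0,
)

# Decode the latents and save the resulting PIL images.
images = vae.output_tensor_to_pil(vae.decode(x_0))
images[0].save("out_0.png")
```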
 
sdxl_models.py CHANGED
@@ -1,12 +1,11 @@
 import math
 import os
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 import safetensors.torch
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
-import xformers.ops
 from PIL import Image
 from torch import nn
 
@@ -21,12 +20,14 @@ class ModelUtils:
         return next(self.parameters()).device
 
     @classmethod
-    def load(cls, load_from: str, device, overrides: Optional[List[str]] = None):
+    def load(cls, load_from: str, device, overrides: Optional[Union[str, List[str]]] = None):
         import load_state_dict_patch
 
         load_from = [load_from]
 
         if overrides is not None:
+            if isinstance(overrides, str):
+                overrides = [overrides]
             load_from += overrides
 
         state_dict = {}
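Since every model class shares `ModelUtils.load`, the widened `overrides` type applies across the board; a minimal sketch (the checkpoint paths are hypothetical placeholders):

```python
from sdxl_models import SDXLControlNet

# Previously `overrides` had to be a list of checkpoint paths; a bare string now
# works too and is normalized to a one-element list internally.
controlnet = SDXLControlNet.load("base.safetensors", device="cuda", overrides="finetune_delta.safetensors")
```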
@@ -1323,51 +1324,57 @@ class TransformerDecoderBlock(nn.Module):
         return hidden_states
 
 
-class AttentionMixin:
-    attention_implementation: Literal["xformers", "torch_2.0_scaled_dot_product"] = "xformers"
-
-    @classmethod
-    def attention(cls, to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
-        batch_size, q_seq_len, channels = hidden_states.shape
-
-        if encoder_hidden_states is not None:
-            kv = encoder_hidden_states
-        else:
-            kv = hidden_states
-
-        kv_seq_len = kv.shape[1]
-
-        query = to_q(hidden_states)
-        key = to_k(kv)
-        value = to_v(kv)
-
-        if AttentionMixin.attention_implementation == "xformers":
-            query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
-            key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
-            value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
-
-            hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
-
-            hidden_states = hidden_states.to(query.dtype)
-            hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
-        elif AttentionMixin.attention_implementation == "torch_2.0_scaled_dot_product":
-            query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
-            key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
-            value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
-
-            hidden_states = F.scaled_dot_product_attention(query, key, value)
-
-            hidden_states = hidden_states.to(query.dtype)
-            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, q_seq_len, channels).contiguous()
-        else:
-            assert False
-
-        hidden_states = to_out(hidden_states)
-
-        return hidden_states
+_attention_implementation: Literal["xformers", "torch_2.0_scaled_dot_product"] = "torch_2.0_scaled_dot_product"
+
+
+def set_attention_implementation(impl: Literal["xformers", "torch_2.0_scaled_dot_product"]):
+    global _attention_implementation
+    _attention_implementation = impl
+
+
+def attention(to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
+    batch_size, q_seq_len, channels = hidden_states.shape
+
+    if encoder_hidden_states is not None:
+        kv = encoder_hidden_states
+    else:
+        kv = hidden_states
+
+    kv_seq_len = kv.shape[1]
+
+    query = to_q(hidden_states)
+    key = to_k(kv)
+    value = to_v(kv)
+
+    if _attention_implementation == "xformers":
+        import xformers.ops
+
+        query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
+        key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+        value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
+
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
+    elif _attention_implementation == "torch_2.0_scaled_dot_product":
+        query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+        key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+        value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value)
+
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, q_seq_len, channels).contiguous()
+    else:
+        assert False
+
+    hidden_states = to_out(hidden_states)
+
+    return hidden_states
 
 
-class Attention(nn.Module, AttentionMixin):
+class Attention(nn.Module):
     def __init__(self, channels, encoder_hidden_states_dim):
         super().__init__()
         self.to_q = nn.Linear(channels, channels, bias=False)
@@ -1376,10 +1383,10 @@ class Attention(nn.Module, AttentionMixin):
         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
 
     def forward(self, hidden_states, encoder_hidden_states=None):
-        return self.attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)
+        return attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)
 
 
-class VaeMidBlockAttention(nn.Module, AttentionMixin):
+class VaeMidBlockAttention(nn.Module):
     def __init__(self, channels):
         super().__init__()
         self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)
@@ -1397,7 +1404,7 @@ class VaeMidBlockAttention(nn.Module, AttentionMixin):
 
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        hidden_states = self.attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)
+        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)
 
         hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
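The attention backend is now module-level state with a setter, defaulting to PyTorch 2.0's `F.scaled_dot_product_attention`, so `xformers` is only imported if a caller explicitly selects it; a minimal opt-in sketch:

```python
import sdxl_models

# Only useful when xformers is installed; the default backend,
# "torch_2.0_scaled_dot_product", needs no extra dependency.
sdxl_models.set_attention_implementation("xformers")
```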
 
 