williamberman committed on
Commit
f280910
1 Parent(s): aa67e5e
Files changed (4)
  1. app.py +30 -22
  2. diffusion.py +5 -5
  3. sdxl.py +36 -16
  4. sdxl_models.py +53 -39
app.py CHANGED
@@ -1,23 +1,22 @@
 import gradio as gr
 import torch

-from diffusers import AutoPipelineForInpainting, StableDiffusionXLPipeline
+from diffusers import AutoPipelineForInpainting
 import diffusers
 from share_btn import community_icon_html, loading_icon_html, share_js
-from sdxl import gen_sdxl_simplified_interface
+from sdxl import sdxl_diffusion_loop
 from sdxl_models import SDXLUNet, SDXLVae, SDXLControlNetPreEncodedControlnetCond
+import torchvision.transforms.functional as TF
+from diffusion import make_sigmas
+from huggingface_hub import hf_hub_download

 device = "cuda" if torch.cuda.is_available() else "cpu"
 pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)

-# TODO - just download individual files
-# StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", variant="fp16") # download weights
-comparing_unet = SDXLUNet.load("/admin/home/william/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/76d28af79639c28a79fa5c6c6468febd3490a37e/unet/diffusion_pytorch_model.fp16.safetensors", device=device)
-# comparing_vae = SDXLVae.load("/admin/home/william/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/76d28af79639c28a79fa5c6c6468febd3490a37e/vae/diffusion_pytorch_model.fp16.safetensors", device=device)
-comparing_vae = SDXLVae.load("/admin/home/william/.cache/huggingface/hub/models--madebyollin--sdxl-vae-fp16-fix/snapshots/4df413ca49271c25289a6482ab97a433f8117d15/diffusion_pytorch_model.safetensors", device=device)
+comparing_unet = SDXLUNet.load(hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "unet/diffusion_pytorch_model.fp16.safetensors"), device=device)
+comparing_vae = SDXLVae.load(hf_hub_download("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors"), device=device)
 comparing_vae.to(torch.float16)
-# comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("/fsx/william/diffusers-utils/output/sdxl_controlnet_inpaint_pre_encoded_controlnet_cond/checkpoint-200000/controlnet/diffusion_pytorch_model.safetensors", device="cuda") # TODO - upload checkpoint
-comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("./controlnet_vae.safetensors", device="cuda") # TODO - upload checkpoint
+comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load(hf_hub_download("williamberman/sdxl_controlnet_inpainting", "sdxl_controlnet_inpaint_pre_encoded_controlnet_cond_checkpoint_200000.safetensors"), device=device)
 comparing_controlnet.to(torch.float16)

 def read_content(file_path: str) -> str:
@@ -45,15 +44,26 @@ def predict(dict, prompt="", negative_prompt="", guidance_scale=7.5, steps=20, s
     init_image = dict["image"].convert("RGB").resize((1024, 1024))
     mask = dict["mask"].convert("RGB").resize((1024, 1024))

-    # output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
-    output_controlnet_vae_encoding = gen_sdxl_simplified_interface(
-        prompts=prompt, negative_prompts=negative_prompt, images=init_image, masks=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps),
-        text_encoder_one=pipe.text_encoder, text_encoder_two=pipe.text_encoder_2, unet=comparing_unet, vae=comparing_vae, controlnet=comparing_controlnet, device=device
+    output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
+
+    image = TF.to_tensor(dict["image"].convert("RGB").resize((1024, 1024)))
+    mask = TF.to_tensor(dict["mask"].convert("L").resize((1024, 1024)))
+    image = image * (mask < 0.5)
+    image = TF.normalize(image, [0.5], [0.5])
+    image = comparing_vae.encode(image[None, :, :, :].to(dtype=comparing_vae.dtype, device=comparing_vae.device)).to(dtype=comparing_controlnet.dtype, device=comparing_controlnet.device)
+    mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None, :, :, :].to(dtype=image.dtype, device=image.device)
+    image = torch.concat((image, mask), dim=1)
+
+    sigmas = make_sigmas(device=comparing_unet.device).to(dtype=comparing_unet.dtype)
+    timesteps = torch.linspace(0, sigmas.numel() - 1, int(steps), dtype=torch.long, device=comparing_unet.device)
+
+    out = sdxl_diffusion_loop(
+        prompts=prompt, negative_prompts=negative_prompt, images=image, guidance_scale=guidance_scale, sigmas=sigmas, timesteps=timesteps,
+        text_encoder_one=pipe.text_encoder, text_encoder_two=pipe.text_encoder_2, unet=comparing_unet, controlnet=comparing_controlnet
     )
+    out = comparing_vae.output_tensor_to_pil(comparing_vae.decode(out))

-    # return output.images[0], output_controlnet_vae_encoding[0], gr.update(visible=True)
-
-    return output_controlnet_vae_encoding[0], gr.update(visible=True)
+    return output.images[0], out[0], gr.update(visible=True)


 css = '''
@@ -107,7 +117,7 @@ with image_blocks as demo:
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row(mobile_collapse=False, equal_height=True):
                    guidance_scale = gr.Number(value=7.5, minimum=1.0, maximum=20.0, step=0.1, label="guidance_scale")
-                   steps = gr.Number(value=20, minimum=10, maximum=30, step=1, label="steps")
+                   steps = gr.Number(value=20, minimum=1, maximum=1000, step=1, label="steps")
                    strength = gr.Number(value=0.99, minimum=0.01, maximum=1.0, step=0.01, label="strength")
                    negative_prompt = gr.Textbox(label="negative_prompt", placeholder="Your negative prompt", info="what you don't want to see in the image")
                with gr.Row(mobile_collapse=False, equal_height=True):
@@ -123,10 +133,8 @@
                share_button = gr.Button("Share to community", elem_id="share-btn",visible=True)


-    # btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
-    # prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
-    btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out_comparing, share_btn_container], api_name='run')
-    prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out_comparing, share_btn_container])
+    btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
+    prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
     share_button.click(None, [], [], _js=share_js)

     gr.Examples(
@@ -155,4 +163,4 @@
        """
    )

-image_blocks.queue(max_size=25).launch()
+image_blocks.queue(max_size=25).launch(share=True)
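For context on the new predict() path above: the ControlNet is conditioned on a VAE-encoded copy of the masked input image with the downsampled mask concatenated as an extra latent channel (hence "PreEncodedControlnetCond"). The sketch below only illustrates that tensor plumbing; it is not part of the commit, and fake_vae_encode is a hypothetical stand-in for comparing_vae.encode so it runs without any checkpoints.

# Illustrative sketch (not part of the commit): building the pre-encoded
# ControlNet conditioning the way predict() does. Shapes mirror the app code:
# 1024x1024 image -> 4x128x128 latents, mask downsampled by 8, concatenated on dim=1.
import torch
import torchvision.transforms.functional as TF
from PIL import Image

def fake_vae_encode(x):
    # stand-in for comparing_vae.encode: (B, 3, H, W) image -> (B, 4, H/8, W/8) latents
    return torch.randn(x.shape[0], 4, x.shape[2] // 8, x.shape[3] // 8)

pil_image = Image.new("RGB", (1024, 1024), color="gray")   # placeholder inputs
pil_mask = Image.new("L", (1024, 1024), color=0)

image = TF.to_tensor(pil_image.convert("RGB").resize((1024, 1024)))
mask = TF.to_tensor(pil_mask.convert("L").resize((1024, 1024)))
image = image * (mask < 0.5)                 # zero out the region to be inpainted
image = TF.normalize(image, [0.5], [0.5])    # map to [-1, 1] as the VAE expects
latents = fake_vae_encode(image[None])
mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None]
controlnet_cond = torch.concat((latents, mask), dim=1)
print(controlnet_cond.shape)                 # torch.Size([1, 5, 128, 128])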
diffusion.py CHANGED
@@ -21,15 +21,14 @@ def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_wei
     x_t = x_T

     for i in range(len(timesteps) - 1, -1, -1):
-        t = timesteps[i]
-
-        sigma = sigmas[i]
+        t = timesteps[i].unsqueeze(0)
+        sigma = sigmas[t]

         if i == 0:
             eps_hat = eps_theta(x_t=x_t, t=t, sigma=sigma)
             x_0_hat = x_t - sigma * eps_hat
         else:
-            dt = sigmas[i - 1] - sigma
+            dt = sigmas[timesteps[i - 1]] - sigma

             dx_by_dt = torch.zeros_like(x_t)
             dx_by_dt_cur = torch.zeros_like(x_t)
@@ -41,7 +40,8 @@ def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_wei
                 eps_hat = eps_theta(x_t=x_t_, t=t_, sigma=sigma)
                 # TODO - note which specific ode this is the solution to and
                 # how input scaling does/doesn't effect the solution
-                dx_by_dt_cur = (x_t_ - sigma * eps_hat) / sigma
+                # dx_by_dt_cur = (x_t_ - sigma * eps_hat) / sigma
+                dx_by_dt_cur = eps_hat
                 dx_by_dt += dx_by_dt_cur * rk_weight

             x_t_minus_1 = x_t + dx_by_dt * dt
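For context on the solver change above: sigmas are now indexed by timestep value (sigmas[t]) rather than by loop position, and the per-step derivative is taken as eps_hat directly, consistent with x_t = x_0 + sigma_t * eps so that dx/dsigma = eps. Below is a minimal, self-contained Euler sketch of that update rule with a toy denoiser; it is not part of the commit and only assumes the same sigma-indexed convention.

# Sketch (not part of the commit): one-sample Euler loop using the same
# conventions as this repo's solvers: x_t = x_0 + sigma_t * eps, dx/dsigma = eps,
# so a step is x_{t-1} = x_t + eps_hat * (sigma_{t-1} - sigma_t).
import torch

def toy_eps_theta(x_t, t, sigma):
    # stand-in denoiser that pretends the clean sample is zero, so eps_hat = x_t / sigma
    return x_t / sigma

sigmas = torch.linspace(0.01, 10.0, 1000)                                  # sigma per timestep
timesteps = torch.linspace(0, sigmas.numel() - 1, 10, dtype=torch.long)

x_t = torch.randn(1, 4, 8, 8) * (sigmas[timesteps[-1]] ** 2 + 1) ** 0.5   # scaled initial noise

for i in range(len(timesteps) - 1, -1, -1):
    t = timesteps[i].unsqueeze(0)
    sigma = sigmas[t]
    eps_hat = toy_eps_theta(x_t=x_t, t=t, sigma=sigma)
    if i == 0:
        x_t = x_t - sigma * eps_hat            # final step: jump straight to x_0_hat
    else:
        dt = sigmas[timesteps[i - 1]] - sigma  # negative: sigma shrinks toward 0
        x_t = x_t + eps_hat * dt

print(x_t.abs().max())                         # ~0 for the toy denoiser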
sdxl.py CHANGED
@@ -667,7 +667,7 @@ def apply_padding(mask, coord):

 @torch.no_grad()
 def sdxl_diffusion_loop(
-    prompts: List[str],
+    prompts: Union[str, List[str]],
     unet,
     text_encoder_one,
     text_encoder_two,
@@ -683,12 +683,13 @@
     negative_prompts=None,
     diffusion_loop=euler_ode_solver_diffusion_loop,
 ):
-    batch_size = len(prompts)
+    if isinstance(prompts, str):
+        prompts = [prompts]

-    if negative_prompts is None:
-        negative_prompts = [""] * batch_size
+    batch_size = len(prompts)

-    prompts += negative_prompts
+    if negative_prompts is not None and guidance_scale > 1.0:
+        prompts += negative_prompts

     encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(
         text_encoder_one,
@@ -699,15 +700,26 @@
     encoder_hidden_states = encoder_hidden_states.to(unet.dtype)
     pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(unet.dtype)

+    if guidance_scale > 1.0:
+        if negative_prompts is None:
+            negative_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
+            negative_pooled_encoder_hidden_states = torch.zeros_like(pooled_encoder_hidden_states)
+        else:
+            encoder_hidden_states, negative_encoder_hidden_states = torch.chunk(encoder_hidden_states, 2)
+            pooled_encoder_hidden_states, negative_pooled_encoder_hidden_states = torch.chunk(pooled_encoder_hidden_states, 2)
+    else:
+        negative_encoder_hidden_states = None
+        negative_pooled_encoder_hidden_states = None
+
     if sigmas is None:
         sigmas = make_sigmas(device=unet.device)

+    if timesteps is None:
+        timesteps = torch.linspace(0, sigmas.numel() - 1, 50, dtype=torch.long, device=unet.device)
+
     if x_T is None:
         x_T = torch.randn((batch_size, 4, 1024 // 8, 1024 // 8), dtype=unet.dtype, device=unet.device, generator=generator)
-        x_T = x_T * ((sigmas.max() ** 2 + 1) ** 0.5)
-
-    if timesteps is None:
-        timesteps = torch.linspace(0, sigmas.numel(), 50, dtype=torch.long, device=unet.device)
+        x_T = x_T * ((sigmas[timesteps[-1]] ** 2 + 1) ** 0.5)

     if micro_conditioning is None:
         micro_conditioning = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.long, device=unet.device)
@@ -723,13 +735,14 @@
     else:
         controlnet_cond = None

-    eps_theta = lambda x_t, t, sigma: sdxl_eps_theta(
-        x_t=x_t,
-        t=t,
-        sigma=sigma,
+    eps_theta = lambda *args, **kwargs: sdxl_eps_theta(
+        *args,
+        **kwargs,
         unet=unet,
         encoder_hidden_states=encoder_hidden_states,
         pooled_encoder_hidden_states=pooled_encoder_hidden_states,
+        negative_encoder_hidden_states=negative_encoder_hidden_states,
+        negative_pooled_encoder_hidden_states=negative_pooled_encoder_hidden_states,
         micro_conditioning=micro_conditioning,
         guidance_scale=guidance_scale,
         controlnet=controlnet,
@@ -750,6 +763,8 @@ def sdxl_eps_theta(
     unet,
     encoder_hidden_states,
     pooled_encoder_hidden_states,
+    negative_encoder_hidden_states,
+    negative_pooled_encoder_hidden_states,
     micro_conditioning,
     guidance_scale,
     controlnet=None,
@@ -761,13 +776,18 @@

     if guidance_scale > 1.0:
         scaled_x_t = torch.concat([scaled_x_t, scaled_x_t])
+
+        encoder_hidden_states = torch.concat((encoder_hidden_states, negative_encoder_hidden_states))
+        pooled_encoder_hidden_states = torch.concat((pooled_encoder_hidden_states, negative_pooled_encoder_hidden_states))
+
         micro_conditioning = torch.concat([micro_conditioning, micro_conditioning])
+
         if controlnet_cond is not None:
             controlnet_cond = torch.concat([controlnet_cond, controlnet_cond])

     if controlnet is not None:
         controlnet_out = controlnet(
-            x_t=scaled_x_t,
+            x_t=scaled_x_t.to(controlnet.dtype),
             t=t,
             encoder_hidden_states=encoder_hidden_states.to(controlnet.dtype),
             micro_conditioning=micro_conditioning.to(controlnet.dtype),
@@ -801,7 +821,7 @@
     )

     if guidance_scale > 1.0:
-        eps_hat_uncond, eps_hat = eps_hat.chunk(2)
+        eps_hat, eps_hat_uncond = eps_hat.chunk(2)

         eps_hat = eps_hat_uncond + guidance_scale * (eps_hat - eps_hat_uncond)

@@ -867,7 +887,7 @@ def gen_sdxl_simplified_interface(

     sigmas = make_sigmas()

-    timesteps = torch.linspace(0, sigmas.numel(), num_inference_steps, dtype=torch.long, device=unet.device)
+    timesteps = torch.linspace(0, sigmas.numel() - 1, num_inference_steps, dtype=torch.long, device=unet.device)

     if images is not None:
         if not isinstance(images, list):
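For context on the reordered classifier-free guidance above: the conditional embeddings are now concatenated first and the negative/unconditional ones second, so eps_hat.chunk(2) must unpack as (eps_hat, eps_hat_uncond). A small self-contained sketch (not part of the commit) showing why the chunk order matters, assuming that batch layout:

# Sketch (not part of the commit): classifier-free guidance with the batch ordered
# [conditional, unconditional], as in the updated sdxl_eps_theta.
import torch

guidance_scale = 7.5

eps_cond = torch.full((1, 4, 8, 8), 1.0)     # stand-in conditional prediction
eps_uncond = torch.full((1, 4, 8, 8), 0.0)   # stand-in negative/unconditional prediction

# the unet runs once on the batch-concatenated inputs
eps_batched = torch.concat([eps_cond, eps_uncond])

eps_hat, eps_hat_uncond = eps_batched.chunk(2)
eps_hat = eps_hat_uncond + guidance_scale * (eps_hat - eps_hat_uncond)
print(eps_hat.mean())  # 7.5: pushed past the conditional prediction by the guidance scale

# unpacking in the old order (uncond first) would instead give -6.5 here,
# i.e. guidance pointing away from the prompt.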
sdxl_models.py CHANGED
@@ -1,6 +1,6 @@
 import math
 import os
-from typing import List, Optional
+from typing import List, Literal, Optional

 import safetensors.torch
 import torch
@@ -1246,16 +1246,14 @@ class ResnetBlock2D(nn.Module):
     def forward(self, hidden_states, temb=None):
         residual = hidden_states

-        if self.time_emb_proj is not None:
-            assert temb is not None
-            temb = self.nonlinearity(temb)
-            temb = self.time_emb_proj(temb)[:, :, None, None]
-
         hidden_states = self.norm1(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.conv1(hidden_states)

-        if temb is not None:
+        if self.time_emb_proj is not None:
+            assert temb is not None
+            temb = self.nonlinearity(temb)
+            temb = self.time_emb_proj(temb)[:, :, None, None]
             hidden_states = hidden_states + temb

         hidden_states = self.norm2(hidden_states)
@@ -1325,7 +1323,51 @@ class TransformerDecoderBlock(nn.Module):
         return hidden_states


-class Attention(nn.Module):
+class AttentionMixin:
+    attention_implementation: Literal["xformers", "torch_2.0_scaled_dot_product"] = "xformers"
+
+    @classmethod
+    def attention(cls, to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
+        batch_size, q_seq_len, channels = hidden_states.shape
+
+        if encoder_hidden_states is not None:
+            kv = encoder_hidden_states
+        else:
+            kv = hidden_states
+
+        kv_seq_len = kv.shape[1]
+
+        query = to_q(hidden_states)
+        key = to_k(kv)
+        value = to_v(kv)
+
+        if AttentionMixin.attention_implementation == "xformers":
+            query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
+            key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+            value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+
+            hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
+
+            hidden_states = hidden_states.to(query.dtype)
+            hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
+        elif AttentionMixin.attention_implementation == "torch_2.0_scaled_dot_product":
+            query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+            key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+            value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).transpose(1, 2).contiguous()
+
+            hidden_states = F.scaled_dot_product_attention(query, key, value)
+
+            hidden_states = hidden_states.to(query.dtype)
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, q_seq_len, channels).contiguous()
+        else:
+            assert False
+
+        hidden_states = to_out(hidden_states)
+
+        return hidden_states
+
+
+class Attention(nn.Module, AttentionMixin):
     def __init__(self, channels, encoder_hidden_states_dim):
         super().__init__()
         self.to_q = nn.Linear(channels, channels, bias=False)
@@ -1334,10 +1376,10 @@ class Attention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))

     def forward(self, hidden_states, encoder_hidden_states=None):
-        return attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)
+        return self.attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)


-class VaeMidBlockAttention(nn.Module):
+class VaeMidBlockAttention(nn.Module, AttentionMixin):
     def __init__(self, channels):
         super().__init__()
         self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)
@@ -1355,7 +1397,7 @@ class VaeMidBlockAttention(nn.Module):

         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)
+        hidden_states = self.attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)

         hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)

@@ -1364,34 +1406,6 @@ class VaeMidBlockAttention(nn.Module):
         return hidden_states


-def attention(to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
-    batch_size, q_seq_len, channels = hidden_states.shape
-
-    if encoder_hidden_states is not None:
-        kv = encoder_hidden_states
-    else:
-        kv = hidden_states
-
-    kv_seq_len = kv.shape[1]
-
-    query = to_q(hidden_states)
-    key = to_k(kv)
-    value = to_v(kv)
-
-    query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
-    key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
-    value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
-
-    hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
-
-    hidden_states = hidden_states.to(query.dtype)
-    hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
-
-    hidden_states = to_out(hidden_states)
-
-    return hidden_states
-
-
 class GEGLU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
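For context on the AttentionMixin added above: both branches run the same q/k/v projections and differ only in head layout, (batch, seq, heads, head_dim) for xformers versus (batch, heads, seq, head_dim) for torch's scaled_dot_product_attention. Below is a self-contained sketch (not part of the commit) checking that the torch 2.0 layout reproduces plain multi-head softmax attention; the xformers branch is omitted so the check runs anywhere, and head_dim=64 is assumed as in Attention.forward.

# Sketch (not part of the commit): verify the reshape/transpose used by the
# "torch_2.0_scaled_dot_product" branch against explicit per-head softmax attention.
import torch
import torch.nn.functional as F

batch_size, seq_len, channels, head_dim = 2, 16, 256, 64
num_heads = channels // head_dim

q = torch.randn(batch_size, seq_len, channels)
k = torch.randn(batch_size, seq_len, channels)
v = torch.randn(batch_size, seq_len, channels)

# torch 2.0 path: (batch, heads, seq, head_dim)
q_ = q.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
k_ = k.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
v_ = v.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(q_, k_, v_)
out = out.transpose(1, 2).reshape(batch_size, seq_len, channels)

# reference: explicit softmax attention per head
attn = torch.softmax(q_ @ k_.transpose(-1, -2) / head_dim**0.5, dim=-1)
ref = (attn @ v_).transpose(1, 2).reshape(batch_size, seq_len, channels)

print(torch.allclose(out, ref, atol=1e-5))  # True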