williamberman committed on
Commit
f0e6b7a
1 Parent(s): 3e48ac3

init comparison

Files changed (5)
  1. app.py +16 -4
  2. diffusion.py +58 -0
  3. load_state_dict_patch.py +415 -0
  4. sdxl.py +962 -0
  5. sdxl_models.py +1375 -0
app.py CHANGED
@@ -1,13 +1,20 @@
 import gradio as gr
 import torch
 
-from diffusers import AutoPipelineForInpainting, UNet2DConditionModel
+from diffusers import AutoPipelineForInpainting
 import diffusers
 from share_btn import community_icon_html, loading_icon_html, share_js
+from sdxl import gen_sdxl_simplified_interface
+from sdxl_models import SDXLUNet, SDXLVae, SDXLControlNetPreEncodedControlnetCond
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)
 
+comparing_unet = SDXLUNet.load_fp16(device=device)
+comparing_vae = SDXLVae.load_fp16_fix(device=device)
+comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("", device="cuda")  # TODO - upload checkpoint
+comparing_controlnet.to(torch.float16)
+
 def read_content(file_path: str) -> str:
     """read the content of target file
     """
@@ -34,8 +41,12 @@ def predict(dict, prompt="", negative_prompt="", guidance_scale=7.5, steps=20, s
     mask = dict["mask"].convert("RGB").resize((1024, 1024))
 
     output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
+    output_controlnet_vae_encoding = gen_sdxl_simplified_interface(
+        prompt=prompt, negative_prompt=negative_prompt, images=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps),
+        text_encoder_one=pipe.text_encoder, text_encoder_two=pipe.text_encoder_2, unet=comparing_unet, vae=comparing_vae, controlnet=comparing_controlnet, device=device
+    )
 
-    return output.images[0], gr.update(visible=True)
+    return output.images[0], output_controlnet_vae_encoding[0], gr.update(visible=True)
 
 
 css = '''
@@ -98,14 +109,15 @@ with image_blocks as demo:
 
         with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img", height=400)
+            image_out_comparing = gr.Image(label="Output", elem_id="output-img-comparing", height=400)
            with gr.Group(elem_id="share-btn-container", visible=False) as share_btn_container:
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn",visible=True)
 
 
-    btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, share_btn_container], api_name='run')
-    prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, share_btn_container])
+    btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
+    prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
     share_button.click(None, [], [], _js=share_js)
 
     gr.Examples(
diffusion.py ADDED
@@ -0,0 +1,58 @@
+import torch
+
+default_num_train_timesteps = 1000
+
+
+@torch.no_grad()
+def make_sigmas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=default_num_train_timesteps, device=None):
+    betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32, device=device) ** 2
+
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+    # TODO - would be nice to use a direct expression for this
+    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+
+    return sigmas
+
+
+@torch.no_grad()
+def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_weights):
+    x_t = x_T
+
+    for i in range(len(timesteps) - 1, -1, -1):
+        t = timesteps[i]
+
+        sigma = sigmas[i]
+
+        if i == 0:
+            eps_hat = eps_theta(x_t=x_t, t=t, sigma=sigma)
+            x_0_hat = x_t - sigma * eps_hat
+        else:
+            dt = sigmas[i - 1] - sigma
+
+            dx_by_dt = torch.zeros_like(x_t)
+            dx_by_dt_cur = torch.zeros_like(x_t)
+
+            for rk_step, rk_weight in rk_steps_weights:
+                dt_ = dt * rk_step
+                t_ = t + dt_
+                x_t_ = x_t + dx_by_dt_cur * dt_
+                eps_hat = eps_theta(x_t=x_t_, t=t_, sigma=sigma)
+                # TODO - note which specific ode this is the solution to and
+                # how input scaling does/doesn't effect the solution
+                dx_by_dt_cur = (x_t_ - sigma * eps_hat) / sigma
+                dx_by_dt += dx_by_dt_cur * rk_weight
+
+            x_t_minus_1 = x_t + dx_by_dt * dt
+
+            x_t = x_t_minus_1
+
+    return x_0_hat
+
+
+euler_ode_solver_diffusion_loop = lambda *args, **kwargs: rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 1]])
+
+heun_ode_solver_diffusion_loop = lambda *args, **kwargs: rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 0.5], [1, 0.5]])
+
+rk4_ode_solver_diffusion_loop = lambda *args, **kwargs: rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 1 / 6], [1 / 2, 1 / 3], [1 / 2, 1 / 3], [1, 1 / 6]])
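A minimal usage sketch (not part of the commit) of how make_sigmas and the solver wrappers compose; it mirrors the way sdxl.py's sdxl_diffusion_loop invokes them, with a hypothetical toy_eps_theta standing in for the UNet-backed noise prediction:

    import torch
    from diffusion import make_sigmas, euler_ode_solver_diffusion_loop

    sigmas = make_sigmas()  # one sigma per training timestep, 1000 by default
    timesteps = torch.linspace(0, sigmas.numel() - 1, 50, dtype=torch.long)

    def toy_eps_theta(x_t, t, sigma):
        # a real eps_theta predicts the noise; this stand-in just returns zeros
        return torch.zeros_like(x_t)

    x_T = torch.randn(1, 4, 128, 128) * ((sigmas.max() ** 2 + 1) ** 0.5)
    x_0 = euler_ode_solver_diffusion_loop(eps_theta=toy_eps_theta, timesteps=timesteps, sigmas=sigmas, x_T=x_T)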
load_state_dict_patch.py ADDED
@@ -0,0 +1,415 @@
1
+ import itertools
2
+ from collections import OrderedDict
3
+ from typing import Any, List, Mapping
4
+
5
+ import torch
6
+ from torch.nn import Module
7
+ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
8
+
9
+ # fmt: off
10
+
11
+ # this patch is for adding the `assign` key to load_state_dict.
12
+ # the code is in pytorch source for version 2.1
13
+
14
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
15
+ missing_keys, unexpected_keys, error_msgs):
16
+ r"""Copies parameters and buffers from :attr:`state_dict` into only
17
+ this module, but not its descendants. This is called on every submodule
18
+ in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
19
+ module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
20
+ For state dicts without metadata, :attr:`local_metadata` is empty.
21
+ Subclasses can achieve class-specific backward compatible loading using
22
+ the version number at `local_metadata.get("version", None)`.
23
+ Additionally, :attr:`local_metadata` can also contain the key
24
+ `assign_to_params_buffers` that indicates whether keys should be
25
+ assigned their corresponding tensor in the state_dict.
26
+
27
+ .. note::
28
+ :attr:`state_dict` is not the same object as the input
29
+ :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
30
+ it can be modified.
31
+
32
+ Args:
33
+ state_dict (dict): a dict containing parameters and
34
+ persistent buffers.
35
+ prefix (str): the prefix for parameters and buffers used in this
36
+ module
37
+ local_metadata (dict): a dict containing the metadata for this module.
38
+ See
39
+ strict (bool): whether to strictly enforce that the keys in
40
+ :attr:`state_dict` with :attr:`prefix` match the names of
41
+ parameters and buffers in this module
42
+ missing_keys (list of str): if ``strict=True``, add missing keys to
43
+ this list
44
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
45
+ keys to this list
46
+ error_msgs (list of str): error messages should be added to this
47
+ list, and will be reported together in
48
+ :meth:`~torch.nn.Module.load_state_dict`
49
+ """
50
+ for hook in self._load_state_dict_pre_hooks.values():
51
+ hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
52
+
53
+ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
54
+ local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
55
+ local_state = {k: v for k, v in local_name_params if v is not None}
56
+ assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
57
+
58
+ for name, param in local_state.items():
59
+ key = prefix + name
60
+ if key in state_dict:
61
+ input_param = state_dict[key]
62
+ if not torch.overrides.is_tensor_like(input_param):
63
+ error_msgs.append('While copying the parameter named "{}", '
64
+ 'expected torch.Tensor or Tensor-like object from checkpoint but '
65
+ 'received {}'
66
+ .format(key, type(input_param)))
67
+ continue
68
+
69
+ # This is used to avoid copying uninitialized parameters into
70
+ # non-lazy modules, since they dont have the hook to do the checks
71
+ # in such case, it will error when accessing the .shape attribute.
72
+ is_param_lazy = torch.nn.parameter.is_lazy(param)
73
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
74
+ if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
75
+ input_param = input_param[0]
76
+
77
+ if not is_param_lazy and input_param.shape != param.shape:
78
+ # local shape should match the one in checkpoint
79
+ error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
80
+ 'the shape in current model is {}.'
81
+ .format(key, input_param.shape, param.shape))
82
+ continue
83
+ try:
84
+ with torch.no_grad():
85
+ if assign_to_params_buffers:
86
+ # Shape checks are already done above
87
+ if (isinstance(param, torch.nn.Parameter) and
88
+ not isinstance(input_param, torch.nn.Parameter)):
89
+ setattr(self, name, torch.nn.Parameter(input_param))
90
+ else:
91
+ setattr(self, name, input_param)
92
+ else:
93
+ param.copy_(input_param)
94
+ except Exception as ex:
95
+ error_msgs.append('While copying the parameter named "{}", '
96
+ 'whose dimensions in the model are {} and '
97
+ 'whose dimensions in the checkpoint are {}, '
98
+ 'an exception occurred : {}.'
99
+ .format(key, param.size(), input_param.size(), ex.args))
100
+ elif strict:
101
+ missing_keys.append(key)
102
+
103
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
104
+ if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
105
+ if extra_state_key in state_dict:
106
+ self.set_extra_state(state_dict[extra_state_key])
107
+ elif strict:
108
+ missing_keys.append(extra_state_key)
109
+ elif strict and (extra_state_key in state_dict):
110
+ unexpected_keys.append(extra_state_key)
111
+
112
+ if strict:
113
+ for key in state_dict.keys():
114
+ if key.startswith(prefix) and key != extra_state_key:
115
+ input_name = key[len(prefix):]
116
+ input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
117
+ if input_name not in self._modules and input_name not in local_state:
118
+ unexpected_keys.append(key)
119
+
120
+ def load_state_dict(self, state_dict: Mapping[str, Any],
121
+ strict: bool = True, assign: bool = False):
122
+ r"""Copies parameters and buffers from :attr:`state_dict` into
123
+ this module and its descendants. If :attr:`strict` is ``True``, then
124
+ the keys of :attr:`state_dict` must exactly match the keys returned
125
+ by this module's :meth:`~torch.nn.Module.state_dict` function.
126
+
127
+ .. warning::
128
+ If :attr:`assign` is ``True`` the optimizer must be created after
129
+ the call to :attr:`load_state_dict`.
130
+
131
+ Args:
132
+ state_dict (dict): a dict containing parameters and
133
+ persistent buffers.
134
+ strict (bool, optional): whether to strictly enforce that the keys
135
+ in :attr:`state_dict` match the keys returned by this module's
136
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
137
+ assign (bool, optional): whether to assign items in the state
138
+ dictionary to their corresponding keys in the module instead
139
+ of copying them inplace into the module's current parameters and buffers.
140
+ When ``False``, the properties of the tensors in the current
141
+ module are preserved while when ``True``, the properties of the
142
+ Tensors in the state dict are preserved.
143
+ Default: ``False``
144
+
145
+ Returns:
146
+ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
147
+ * **missing_keys** is a list of str containing the missing keys
148
+ * **unexpected_keys** is a list of str containing the unexpected keys
149
+
150
+ Note:
151
+ If a parameter or buffer is registered as ``None`` and its corresponding key
152
+ exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
153
+ ``RuntimeError``.
154
+ """
155
+ if not isinstance(state_dict, Mapping):
156
+ raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
157
+
158
+ missing_keys: List[str] = []
159
+ unexpected_keys: List[str] = []
160
+ error_msgs: List[str] = []
161
+
162
+ # copy state_dict so _load_from_state_dict can modify it
163
+ metadata = getattr(state_dict, '_metadata', None)
164
+ state_dict = OrderedDict(state_dict)
165
+ if metadata is not None:
166
+ # mypy isn't aware that "_metadata" exists in state_dict
167
+ state_dict._metadata = metadata # type: ignore[attr-defined]
168
+
169
+ def load(module, local_state_dict, prefix=''):
170
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
171
+ if assign:
172
+ local_metadata['assign_to_params_buffers'] = assign
173
+ module._load_from_state_dict(
174
+ local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
175
+ for name, child in module._modules.items():
176
+ if child is not None:
177
+ child_prefix = prefix + name + '.'
178
+ child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
179
+ load(child, child_state_dict, child_prefix)
180
+
181
+ # Note that the hook can modify missing_keys and unexpected_keys.
182
+ incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
183
+ for hook in module._load_state_dict_post_hooks.values():
184
+ out = hook(module, incompatible_keys)
185
+ assert out is None, (
186
+ "Hooks registered with ``register_load_state_dict_post_hook`` are not"
187
+ "expected to return new values, if incompatible_keys need to be modified,"
188
+ "it should be done inplace."
189
+ )
190
+
191
+ load(self, state_dict)
192
+ del load
193
+
194
+ if strict:
195
+ if len(unexpected_keys) > 0:
196
+ error_msgs.insert(
197
+ 0, 'Unexpected key(s) in state_dict: {}. '.format(
198
+ ', '.join('"{}"'.format(k) for k in unexpected_keys)))
199
+ if len(missing_keys) > 0:
200
+ error_msgs.insert(
201
+ 0, 'Missing key(s) in state_dict: {}. '.format(
202
+ ', '.join('"{}"'.format(k) for k in missing_keys)))
203
+
204
+ if len(error_msgs) > 0:
205
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
206
+ self.__class__.__name__, "\n\t".join(error_msgs)))
207
+ return _IncompatibleKeys(missing_keys, unexpected_keys)
208
+
209
+ if [int(x) for x in torch.__version__.split('.')[0:2]] < [2, 1]:
210
+ Module._load_from_state_dict = _load_from_state_dict
211
+ Module.load_state_dict = load_state_dict
212
+
213
+ # this patch is for adding the `assign` key to load_state_dict.
214
+ # the code is in pytorch source for version 2.1
215
+
216
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
217
+ missing_keys, unexpected_keys, error_msgs):
218
+ r"""Copies parameters and buffers from :attr:`state_dict` into only
219
+ this module, but not its descendants. This is called on every submodule
220
+ in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
221
+ module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
222
+ For state dicts without metadata, :attr:`local_metadata` is empty.
223
+ Subclasses can achieve class-specific backward compatible loading using
224
+ the version number at `local_metadata.get("version", None)`.
225
+ Additionally, :attr:`local_metadata` can also contain the key
226
+ `assign_to_params_buffers` that indicates whether keys should be
227
+ assigned their corresponding tensor in the state_dict.
228
+
229
+ .. note::
230
+ :attr:`state_dict` is not the same object as the input
231
+ :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
232
+ it can be modified.
233
+
234
+ Args:
235
+ state_dict (dict): a dict containing parameters and
236
+ persistent buffers.
237
+ prefix (str): the prefix for parameters and buffers used in this
238
+ module
239
+ local_metadata (dict): a dict containing the metadata for this module.
240
+ See
241
+ strict (bool): whether to strictly enforce that the keys in
242
+ :attr:`state_dict` with :attr:`prefix` match the names of
243
+ parameters and buffers in this module
244
+ missing_keys (list of str): if ``strict=True``, add missing keys to
245
+ this list
246
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
247
+ keys to this list
248
+ error_msgs (list of str): error messages should be added to this
249
+ list, and will be reported together in
250
+ :meth:`~torch.nn.Module.load_state_dict`
251
+ """
252
+ for hook in self._load_state_dict_pre_hooks.values():
253
+ hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
254
+
255
+ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
256
+ local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
257
+ local_state = {k: v for k, v in local_name_params if v is not None}
258
+ assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
259
+
260
+ for name, param in local_state.items():
261
+ key = prefix + name
262
+ if key in state_dict:
263
+ input_param = state_dict[key]
264
+ if not torch.overrides.is_tensor_like(input_param):
265
+ error_msgs.append('While copying the parameter named "{}", '
266
+ 'expected torch.Tensor or Tensor-like object from checkpoint but '
267
+ 'received {}'
268
+ .format(key, type(input_param)))
269
+ continue
270
+
271
+ # This is used to avoid copying uninitialized parameters into
272
+ # non-lazy modules, since they dont have the hook to do the checks
273
+ # in such case, it will error when accessing the .shape attribute.
274
+ is_param_lazy = torch.nn.parameter.is_lazy(param)
275
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
276
+ if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
277
+ input_param = input_param[0]
278
+
279
+ if not is_param_lazy and input_param.shape != param.shape:
280
+ # local shape should match the one in checkpoint
281
+ error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
282
+ 'the shape in current model is {}.'
283
+ .format(key, input_param.shape, param.shape))
284
+ continue
285
+ try:
286
+ with torch.no_grad():
287
+ if assign_to_params_buffers:
288
+ # Shape checks are already done above
289
+ if (isinstance(param, torch.nn.Parameter) and
290
+ not isinstance(input_param, torch.nn.Parameter)):
291
+ setattr(self, name, torch.nn.Parameter(input_param))
292
+ else:
293
+ setattr(self, name, input_param)
294
+ else:
295
+ param.copy_(input_param)
296
+ except Exception as ex:
297
+ error_msgs.append('While copying the parameter named "{}", '
298
+ 'whose dimensions in the model are {} and '
299
+ 'whose dimensions in the checkpoint are {}, '
300
+ 'an exception occurred : {}.'
301
+ .format(key, param.size(), input_param.size(), ex.args))
302
+ elif strict:
303
+ missing_keys.append(key)
304
+
305
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
306
+ if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
307
+ if extra_state_key in state_dict:
308
+ self.set_extra_state(state_dict[extra_state_key])
309
+ elif strict:
310
+ missing_keys.append(extra_state_key)
311
+ elif strict and (extra_state_key in state_dict):
312
+ unexpected_keys.append(extra_state_key)
313
+
314
+ if strict:
315
+ for key in state_dict.keys():
316
+ if key.startswith(prefix) and key != extra_state_key:
317
+ input_name = key[len(prefix):]
318
+ input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
319
+ if input_name not in self._modules and input_name not in local_state:
320
+ unexpected_keys.append(key)
321
+
322
+ def load_state_dict(self, state_dict: Mapping[str, Any],
323
+ strict: bool = True, assign: bool = False):
324
+ r"""Copies parameters and buffers from :attr:`state_dict` into
325
+ this module and its descendants. If :attr:`strict` is ``True``, then
326
+ the keys of :attr:`state_dict` must exactly match the keys returned
327
+ by this module's :meth:`~torch.nn.Module.state_dict` function.
328
+
329
+ .. warning::
330
+ If :attr:`assign` is ``True`` the optimizer must be created after
331
+ the call to :attr:`load_state_dict`.
332
+
333
+ Args:
334
+ state_dict (dict): a dict containing parameters and
335
+ persistent buffers.
336
+ strict (bool, optional): whether to strictly enforce that the keys
337
+ in :attr:`state_dict` match the keys returned by this module's
338
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
339
+ assign (bool, optional): whether to assign items in the state
340
+ dictionary to their corresponding keys in the module instead
341
+ of copying them inplace into the module's current parameters and buffers.
342
+ When ``False``, the properties of the tensors in the current
343
+ module are preserved while when ``True``, the properties of the
344
+ Tensors in the state dict are preserved.
345
+ Default: ``False``
346
+
347
+ Returns:
348
+ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
349
+ * **missing_keys** is a list of str containing the missing keys
350
+ * **unexpected_keys** is a list of str containing the unexpected keys
351
+
352
+ Note:
353
+ If a parameter or buffer is registered as ``None`` and its corresponding key
354
+ exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
355
+ ``RuntimeError``.
356
+ """
357
+ if not isinstance(state_dict, Mapping):
358
+ raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
359
+
360
+ missing_keys: List[str] = []
361
+ unexpected_keys: List[str] = []
362
+ error_msgs: List[str] = []
363
+
364
+ # copy state_dict so _load_from_state_dict can modify it
365
+ metadata = getattr(state_dict, '_metadata', None)
366
+ state_dict = OrderedDict(state_dict)
367
+ if metadata is not None:
368
+ # mypy isn't aware that "_metadata" exists in state_dict
369
+ state_dict._metadata = metadata # type: ignore[attr-defined]
370
+
371
+ def load(module, local_state_dict, prefix=''):
372
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
373
+ if assign:
374
+ local_metadata['assign_to_params_buffers'] = assign
375
+ module._load_from_state_dict(
376
+ local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
377
+ for name, child in module._modules.items():
378
+ if child is not None:
379
+ child_prefix = prefix + name + '.'
380
+ child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
381
+ load(child, child_state_dict, child_prefix)
382
+
383
+ # Note that the hook can modify missing_keys and unexpected_keys.
384
+ incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
385
+ for hook in module._load_state_dict_post_hooks.values():
386
+ out = hook(module, incompatible_keys)
387
+ assert out is None, (
388
+ "Hooks registered with ``register_load_state_dict_post_hook`` are not"
389
+ "expected to return new values, if incompatible_keys need to be modified,"
390
+ "it should be done inplace."
391
+ )
392
+
393
+ load(self, state_dict)
394
+ del load
395
+
396
+ if strict:
397
+ if len(unexpected_keys) > 0:
398
+ error_msgs.insert(
399
+ 0, 'Unexpected key(s) in state_dict: {}. '.format(
400
+ ', '.join('"{}"'.format(k) for k in unexpected_keys)))
401
+ if len(missing_keys) > 0:
402
+ error_msgs.insert(
403
+ 0, 'Missing key(s) in state_dict: {}. '.format(
404
+ ', '.join('"{}"'.format(k) for k in missing_keys)))
405
+
406
+ if len(error_msgs) > 0:
407
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
408
+ self.__class__.__name__, "\n\t".join(error_msgs)))
409
+ return _IncompatibleKeys(missing_keys, unexpected_keys)
410
+
411
+ if [int(x) for x in torch.__version__.split('.')[0:2]] < [2, 1]:
412
+ Module._load_from_state_dict = _load_from_state_dict
413
+ Module.load_state_dict = load_state_dict
414
+
415
+ # fmt: on
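A minimal sketch (not part of the commit) of the behavior this patch enables: on torch versions before 2.1, importing the module monkey-patches torch.nn.Module so that load_state_dict accepts assign=True, which assigns the checkpoint tensors to the module instead of copying them into the existing parameters:

    import torch

    import load_state_dict_patch  # noqa: F401 - importing applies the patch on torch < 2.1

    model = torch.nn.Linear(4, 4)
    state_dict = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}

    # with assign=True the module ends up holding the state_dict tensors themselves
    model.load_state_dict(state_dict, assign=True)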
sdxl.py ADDED
@@ -0,0 +1,962 @@
1
+ import itertools
2
+ import os
3
+ import random
4
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import safetensors.torch
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms
11
+ import torchvision.transforms.functional as TF
12
+ import wandb
13
+ import webdataset as wds
14
+ from PIL import Image
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+ from torch.utils.data import default_collate
17
+ from transformers import (CLIPTextModel, CLIPTextModelWithProjection,
18
+ CLIPTokenizerFast)
19
+
20
+ from diffusion import (default_num_train_timesteps,
21
+ euler_ode_solver_diffusion_loop, make_sigmas)
22
+ from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull,
23
+ SDXLControlNetPreEncodedControlnetCond, SDXLUNet,
24
+ SDXLVae)
25
+
26
+
27
+ class SDXLTraining:
28
+ text_encoder_one: CLIPTextModel
29
+ text_encoder_two: CLIPTextModelWithProjection
30
+ vae: SDXLVae
31
+ sigmas: torch.Tensor
32
+ unet: SDXLUNet
33
+ adapter: Optional[SDXLAdapter]
34
+ controlnet: Optional[Union[SDXLControlNet, SDXLControlNetFull]]
35
+
36
+ train_unet: bool
37
+ train_unet_up_blocks: bool
38
+
39
+ mixed_precision: Optional[torch.dtype]
40
+ timestep_sampling: Literal["uniform", "cubic"]
41
+
42
+ validation_images_logged: bool
43
+ log_validation_input_images_every_time: bool
44
+
45
+ get_sdxl_conditioning_images: Callable[[Image.Image], Dict[str, Any]]
46
+
47
+ def __init__(
48
+ self,
49
+ device,
50
+ train_unet,
51
+ get_sdxl_conditioning_images,
52
+ train_unet_up_blocks=False,
53
+ unet_resume_from=None,
54
+ controlnet_cls=None,
55
+ controlnet_resume_from=None,
56
+ adapter_cls=None,
57
+ adapter_resume_from=None,
58
+ mixed_precision=None,
59
+ timestep_sampling="uniform",
60
+ log_validation_input_images_every_time=True,
61
+ ):
62
+ self.text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16)
63
+ self.text_encoder_one.to(device=device)
64
+ self.text_encoder_one.requires_grad_(False)
65
+ self.text_encoder_one.eval()
66
+
67
+ self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16)
68
+ self.text_encoder_two.to(device=device)
69
+ self.text_encoder_two.requires_grad_(False)
70
+ self.text_encoder_two.eval()
71
+
72
+ self.vae = SDXLVae.load_fp16_fix(device=device)
73
+ self.vae.requires_grad_(False)
74
+ self.vae.eval()
75
+
76
+ self.sigmas = make_sigmas(device=device)
77
+
78
+ if train_unet:
79
+ if unet_resume_from is None:
80
+ self.unet = SDXLUNet.load_fp32(device=device)
81
+ else:
82
+ self.unet = SDXLUNet.load(unet_resume_from, device=device)
83
+ self.unet.requires_grad_(True)
84
+ self.unet.train()
85
+ self.unet = DDP(self.unet, device_ids=[device])
86
+ elif train_unet_up_blocks:
87
+ if unet_resume_from is None:
88
+ self.unet = SDXLUNet.load_fp32(device=device)
89
+ else:
90
+ self.unet = SDXLUNet.load_fp32(device=device, overrides=[unet_resume_from])
91
+ self.unet.requires_grad_(False)
92
+ self.unet.eval()
93
+ self.unet.up_blocks.requires_grad_(True)
94
+ self.unet.up_blocks.train()
95
+ self.unet = DDP(self.unet, device_ids=[device], find_unused_parameters=True)
96
+ else:
97
+ self.unet = SDXLUNet.load_fp16(device=device)
98
+ self.unet.requires_grad_(False)
99
+ self.unet.eval()
100
+
101
+ if controlnet_cls is not None:
102
+ if controlnet_resume_from is None:
103
+ self.controlnet = controlnet_cls.from_unet(self.unet)
104
+ self.controlnet.to(device)
105
+ else:
106
+ self.controlnet = controlnet_cls.load(controlnet_resume_from, device=device)
107
+ self.controlnet.train()
108
+ self.controlnet.requires_grad_(True)
109
+ # TODO add back
110
+ # controlnet.enable_gradient_checkpointing()
111
+ # TODO - should be able to remove find_unused_parameters. Comes from pre encoded controlnet
112
+ self.controlnet = DDP(self.controlnet, device_ids=[device], find_unused_parameters=True)
113
+ else:
114
+ self.controlnet = None
115
+
116
+ if adapter_cls is not None:
117
+ if adapter_resume_from is None:
118
+ self.adapter = adapter_cls()
119
+ self.adapter.to(device=device)
120
+ else:
121
+ self.adapter = adapter_cls.load(adapter_resume_from, device=device)
122
+ self.adapter.train()
123
+ self.adapter.requires_grad_(True)
124
+ self.adapter = DDP(self.adapter, device_ids=[device])
125
+ else:
126
+ self.adapter = None
127
+
128
+ self.mixed_precision = mixed_precision
129
+ self.timestep_sampling = timestep_sampling
130
+
131
+ self.validation_images_logged = False
132
+ self.log_validation_input_images_every_time = log_validation_input_images_every_time
133
+
134
+ self.get_sdxl_conditioning_images = get_sdxl_conditioning_images
135
+
136
+ self.train_unet = train_unet
137
+ self.train_unet_up_blocks = train_unet_up_blocks
138
+
139
+ def train_step(self, batch):
140
+ with torch.no_grad():
141
+ if isinstance(self.unet, DDP):
142
+ unet_dtype = self.unet.module.dtype
143
+ unet_device = self.unet.module.device
144
+ else:
145
+ unet_dtype = self.unet.dtype
146
+ unet_device = self.unet.device
147
+
148
+ micro_conditioning = batch["micro_conditioning"].to(device=unet_device)
149
+
150
+ image = batch["image"].to(self.vae.device, dtype=self.vae.dtype)
151
+ latents = self.vae.encode(image).to(dtype=unet_dtype)
152
+
153
+ text_input_ids_one = batch["text_input_ids_one"].to(self.text_encoder_one.device)
154
+ text_input_ids_two = batch["text_input_ids_two"].to(self.text_encoder_two.device)
155
+
156
+ encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(self.text_encoder_one, self.text_encoder_two, text_input_ids_one, text_input_ids_two)
157
+
158
+ encoder_hidden_states = encoder_hidden_states.to(dtype=unet_dtype)
159
+ pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(dtype=unet_dtype)
160
+
161
+ bsz = latents.shape[0]
162
+
163
+ if self.timestep_sampling == "uniform":
164
+ timesteps = torch.randint(0, default_num_train_timesteps, (bsz,), device=unet_device)
165
+ elif self.timestep_sampling == "cubic":
166
+ # Cubic sampling to sample a random timestep for each image
167
+ timesteps = torch.rand((bsz,), device=unet_device)
168
+ timesteps = (1 - timesteps**3) * default_num_train_timesteps
169
+ timesteps = timesteps.long()
170
+ timesteps = timesteps.clamp(0, default_num_train_timesteps - 1)
171
+ else:
172
+ assert False
173
+
174
+ sigmas_ = self.sigmas[timesteps].to(dtype=latents.dtype)[:, None, None, None]  # reshape so the per-sample sigma broadcasts over latent channels and spatial dims
175
+
176
+ noise = torch.randn_like(latents)
177
+
178
+ noisy_latents = latents + noise * sigmas_
179
+
180
+ scaled_noisy_latents = noisy_latents / ((sigmas_**2 + 1) ** 0.5)
181
+
182
+ if "conditioning_image" in batch:
183
+ conditioning_image = batch["conditioning_image"].to(unet_device)
184
+
185
+ if self.controlnet is not None and isinstance(self.controlnet, SDXLControlNetPreEncodedControlnetCond):
186
+ controlnet_device = self.controlnet.module.device
187
+ controlnet_dtype = self.controlnet.module.dtype
188
+ conditioning_image = self.vae.encode(conditioning_image.to(self.vae.dtype)).to(device=controlnet_device, dtype=controlnet_dtype)
189
+ conditioning_image_mask = TF.resize(batch["conditioning_image_mask"], conditioning_image.shape[2:]).to(device=controlnet_device, dtype=controlnet_dtype)
190
+ conditioning_image = torch.concat((conditioning_image, conditioning_image_mask), dim=1)
191
+
192
+ with torch.autocast(
193
+ "cuda",
194
+ self.mixed_precision,
195
+ enabled=self.mixed_precision is not None,
196
+ ):
197
+ down_block_additional_residuals = None
198
+ mid_block_additional_residual = None
199
+ add_to_down_block_inputs = None
200
+ add_to_output = None
201
+
202
+ if self.adapter is not None:
203
+ down_block_additional_residuals = self.adapter(conditioning_image)
204
+
205
+ if self.controlnet is not None:
206
+ controlnet_out = self.controlnet(
207
+ x_t=scaled_noisy_latents,
208
+ t=timesteps,
209
+ encoder_hidden_states=encoder_hidden_states,
210
+ micro_conditioning=micro_conditioning,
211
+ pooled_encoder_hidden_states=pooled_encoder_hidden_states,
212
+ controlnet_cond=conditioning_image,
213
+ )
214
+
215
+ down_block_additional_residuals = controlnet_out["down_block_res_samples"]
216
+ mid_block_additional_residual = controlnet_out["mid_block_res_sample"]
217
+ add_to_down_block_inputs = controlnet_out.get("add_to_down_block_inputs", None)
218
+ add_to_output = controlnet_out.get("add_to_output", None)
219
+
220
+ model_pred = self.unet(
221
+ x_t=scaled_noisy_latents,
222
+ t=timesteps,
223
+ encoder_hidden_states=encoder_hidden_states,
224
+ micro_conditioning=micro_conditioning,
225
+ pooled_encoder_hidden_states=pooled_encoder_hidden_states,
226
+ down_block_additional_residuals=down_block_additional_residuals,
227
+ mid_block_additional_residual=mid_block_additional_residual,
228
+ add_to_down_block_inputs=add_to_down_block_inputs,
229
+ add_to_output=add_to_output,
230
+ ).sample
231
+
232
+ loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
233
+
234
+ return loss
235
+
236
+ @torch.no_grad()
237
+ def log_validation(self, step, num_validation_images: int, validation_prompts: Optional[List[str]] = None, validation_images: Optional[List[str]] = None):
238
+ if isinstance(self.unet, DDP):
239
+ unet = self.unet.module
240
+ unet.eval()
241
+ unet_set_to_eval = True
242
+ else:
243
+ unet = self.unet
244
+ unet_set_to_eval = False
245
+
246
+ if self.adapter is not None:
247
+ adapter = self.adapter.module
248
+ adapter.eval()
249
+ else:
250
+ adapter = None
251
+
252
+ if self.controlnet is not None:
253
+ controlnet = self.controlnet.module
254
+ controlnet.eval()
255
+ else:
256
+ controlnet = None
257
+
258
+ formatted_validation_images = None
259
+
260
+ if validation_images is not None:
261
+ formatted_validation_images = []
262
+ wandb_validation_images = []
263
+
264
+ for validation_image_path in validation_images:
265
+ validation_image = Image.open(validation_image_path)
266
+ validation_image = validation_image.convert("RGB")
267
+ validation_image = validation_image.resize((1024, 1024))
268
+
269
+ conditioning_images = self.get_sdxl_conditioning_images(validation_image)
270
+
271
+ conditioning_image = conditioning_images["conditioning_image"]
272
+
273
+ if self.controlnet is not None and isinstance(self.controlnet, SDXLControlNetPreEncodedControlnetCond):
274
+ conditioning_image = self.vae.encode(conditioning_image[None, :, :, :].to(self.vae.device, dtype=self.vae.dtype))
275
+ conditioning_mask_image = TF.resize(conditioning_images["conditioning_image_mask"], conditioning_image.shape[2:]).to(device=conditioning_image.device, dtype=conditioning_image.dtype)
276
+ conditioning_image = torch.concat((conditioning_image, conditioning_mask_image), dim=1)
277
+
278
+ formatted_validation_images.append(conditioning_image)
279
+ wandb_validation_images.append(wandb.Image(conditioning_images["conditioning_image_as_pil"]))
280
+
281
+ if self.log_validation_input_images_every_time or not self.validation_images_logged:
282
+ wandb.log({"validation_conditioning": wandb_validation_images}, step=step)
283
+ self.validation_images_logged = True
284
+
285
+ generator = torch.Generator().manual_seed(0)
286
+
287
+ output_validation_images = []
288
+
289
+ for formatted_validation_image, validation_prompt in zip(formatted_validation_images, validation_prompts):
290
+ for _ in range(num_validation_images):
291
+ with torch.autocast("cuda"):
292
+ x_0 = sdxl_diffusion_loop(
293
+ prompts=validation_prompt,
294
+ images=formatted_validation_image,
295
+ unet=unet,
296
+ text_encoder_one=self.text_encoder_one,
297
+ text_encoder_two=self.text_encoder_two,
298
+ controlnet=controlnet,
299
+ adapter=adapter,
300
+ sigmas=self.sigmas,
301
+ generator=generator,
302
+ )
303
+
304
+ x_0 = self.vae.decode(x_0)
305
+ x_0 = self.vae.output_tensor_to_pil(x_0)[0]
306
+
307
+ output_validation_images.append(wandb.Image(x_0, caption=validation_prompt))
308
+
309
+ wandb.log({"validation": output_validation_images}, step=step)
310
+
311
+ if unet_set_to_eval:
312
+ unet.train()
313
+
314
+ if adapter is not None:
315
+ adapter.train()
316
+
317
+ if controlnet is not None:
318
+ controlnet.train()
319
+
320
+ def parameters(self):
321
+ if self.train_unet:
322
+ return self.unet.parameters()
323
+
324
+ if self.controlnet is not None and self.train_unet_up_blocks:
325
+ return itertools.chain(self.controlnet.parameters(), self.unet.up_blocks.parameters())
326
+
327
+ if self.controlnet is not None:
328
+ return self.controlnet.parameters()
329
+
330
+ if self.adapter is not None:
331
+ return self.adapter.parameters()
332
+
333
+ assert False
334
+
335
+ def save(self, save_to):
336
+ if self.train_unet:
337
+ safetensors.torch.save_file(self.unet.module.state_dict(), os.path.join(save_to, "unet.safetensors"))
338
+
339
+ if self.controlnet is not None and self.train_unet_up_blocks:
340
+ safetensors.torch.save_file(self.controlnet.module.state_dict(), os.path.join(save_to, "controlnet.safetensors"))
341
+ safetensors.torch.save_file(self.unet.module.up_blocks.state_dict(), os.path.join(save_to, "unet.safetensors"))
342
+
343
+ if self.controlnet is not None:
344
+ safetensors.torch.save_file(self.controlnet.module.state_dict(), os.path.join(save_to, "controlnet.safetensors"))
345
+
346
+ if self.adapter is not None:
347
+ safetensors.torch.save_file(self.adapter.module.state_dict(), os.path.join(save_to, "adapter.safetensors"))
348
+
349
+
350
+ def get_sdxl_dataset(train_shards: str, shuffle_buffer_size: int, batch_size: int, proportion_empty_prompts: float, get_sdxl_conditioning_images=None):
351
+ dataset = (
352
+ wds.WebDataset(
353
+ train_shards,
354
+ resampled=True,
355
+ handler=wds.ignore_and_continue,
356
+ )
357
+ .shuffle(shuffle_buffer_size)
358
+ .decode("pil", handler=wds.ignore_and_continue)
359
+ .rename(
360
+ image="jpg;png;jpeg;webp",
361
+ text="text;txt;caption",
362
+ metadata="json",
363
+ handler=wds.warn_and_continue,
364
+ )
365
+ .map(lambda d: make_sample(d, proportion_empty_prompts=proportion_empty_prompts, get_sdxl_conditioning_images=get_sdxl_conditioning_images))
366
+ .select(lambda sample: "conditioning_image" not in sample or sample["conditioning_image"] is not None)
367
+ )
368
+
369
+ dataset = dataset.batched(batch_size, partial=False, collation_fn=default_collate)
370
+
371
+ return dataset
372
+
373
+
374
+ @torch.no_grad()
375
+ def make_sample(d, proportion_empty_prompts, get_sdxl_conditioning_images=None):
376
+ image = d["image"]
377
+ metadata = d["metadata"]
378
+
379
+ if random.random() < proportion_empty_prompts:
380
+ text = ""
381
+ else:
382
+ text = d["text"]
383
+
384
+ c_top, c_left, _, _ = get_random_crop_params([image.height, image.width], [1024, 1024])
385
+
386
+ original_width = int(metadata.get("original_width", 0.0))
387
+ original_height = int(metadata.get("original_height", 0.0))
388
+
389
+ micro_conditioning = torch.tensor([original_width, original_height, c_top, c_left, 1024, 1024])
390
+
391
+ text_input_ids_one = sdxl_tokenize_one(text)
392
+
393
+ text_input_ids_two = sdxl_tokenize_two(text)
394
+
395
+ image = image.convert("RGB")
396
+
397
+ image = TF.resize(
398
+ image,
399
+ 1024,
400
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
401
+ )
402
+
403
+ image = TF.crop(
404
+ image,
405
+ c_top,
406
+ c_left,
407
+ 1024,
408
+ 1024,
409
+ )
410
+
411
+ sample = {
412
+ "micro_conditioning": micro_conditioning,
413
+ "text_input_ids_one": text_input_ids_one,
414
+ "text_input_ids_two": text_input_ids_two,
415
+ "image": SDXLVae.input_pil_to_tensor(image),
416
+ }
417
+
418
+ if get_sdxl_conditioning_images is not None:
419
+ conditioning_images = get_sdxl_conditioning_images(image)
420
+
421
+ sample["conditioning_image"] = conditioning_images["conditioning_image"]
422
+
423
+ if conditioning_images["conditioning_image_mask"] is not None:
424
+ sample["conditioning_image_mask"] = conditioning_images["conditioning_image_mask"]
425
+
426
+ return sample
427
+
428
+
429
+ def get_random_crop_params(input_size: Tuple[int, int], output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
430
+ h, w = input_size
431
+
432
+ th, tw = output_size
433
+
434
+ if h < th or w < tw:
435
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
436
+
437
+ if w == tw and h == th:
438
+ return 0, 0, h, w
439
+
440
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
441
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
442
+
443
+ return i, j, th, tw
444
+
445
+
446
+ def get_sdxl_conditioning_images(image, adapter_type=None, controlnet_type=None, controlnet_variant=None, open_pose=None, conditioning_image_mask=None):
447
+ resolution = image.width
448
+
449
+ if adapter_type == "openpose":
450
+ conditioning_image = open_pose(image, detect_resolution=resolution, image_resolution=resolution, return_pil=False)
451
+
452
+ if (conditioning_image == 0).all():
453
+ return dict(conditioning_image=None, conditioning_image_mask=None, conditioning_image_as_pil=None)
454
+
455
+ conditioning_image_as_pil = Image.fromarray(conditioning_image)
456
+
457
+ conditioning_image = TF.to_tensor(conditioning_image)
458
+
459
+ if controlnet_type == "canny":
460
+ import cv2
461
+
462
+ conditioning_image = np.array(image)
463
+ conditioning_image = cv2.Canny(conditioning_image, 100, 200)
464
+ conditioning_image = conditioning_image[:, :, None]
465
+ conditioning_image = np.concatenate([conditioning_image, conditioning_image, conditioning_image], axis=2)
466
+
467
+ conditioning_image_as_pil = Image.fromarray(conditioning_image)
468
+
469
+ conditioning_image = TF.to_tensor(conditioning_image)
470
+
471
+ if controlnet_type == "inpainting":
472
+ if conditioning_image_mask is None:
473
+ if random.random() <= 0.25:
474
+ conditioning_image_mask = np.ones((resolution, resolution), np.float32)
475
+ else:
476
+ conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)
477
+
478
+ conditioning_image_mask = torch.from_numpy(conditioning_image_mask)
479
+
480
+ conditioning_image_mask = conditioning_image_mask[None, :, :]
481
+
482
+ conditioning_image = TF.to_tensor(image)
483
+
484
+ if controlnet_variant == "pre_encoded_controlnet_cond":
485
+ # where the mask is 1, zero out the pixels. Note that this requires the mask to be concatenated
486
+ # with the image so that the network knows the zeroed-out pixels come from the mask and
487
+ # are not just zero in the original image
488
+ conditioning_image = conditioning_image * (conditioning_image_mask < 0.5)
489
+
490
+ conditioning_image_as_pil = TF.to_pil_image(conditioning_image)
491
+
492
+ conditioning_image = TF.normalize(conditioning_image, [0.5], [0.5])
493
+ else:
494
+ # Just zero out the pixels which will be masked
495
+ conditioning_image_as_pil = TF.to_pil_image(conditioning_image * (conditioning_image_mask < 0.5))
496
+
497
+ # where mask is set to 1, set to -1 "special" masked image pixel.
498
+ # -1 is outside of the 0-1 range that the controlnet normalized
499
+ # input is in.
500
+ conditioning_image = conditioning_image * (conditioning_image_mask < 0.5) + -1.0 * (conditioning_image_mask >= 0.5)
501
+
502
+ return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)
503
+
504
+
505
+ # TODO: would be nice to just call a function from a tokenizers https://github.com/huggingface/tokenizers
506
+ # i.e. afaik tokenizing shouldn't require holding any state
507
+
508
+ tokenizer_one = CLIPTokenizerFast.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer")
509
+
510
+ tokenizer_two = CLIPTokenizerFast.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2")
511
+
512
+
513
+ def sdxl_tokenize_one(prompts):
514
+ return tokenizer_one(
515
+ prompts,
516
+ padding="max_length",
517
+ max_length=tokenizer_one.model_max_length,
518
+ truncation=True,
519
+ return_tensors="pt",
520
+ ).input_ids[0]
521
+
522
+
523
+ def sdxl_tokenize_two(prompts):
524
+ return tokenizer_two(
525
+ prompts,
526
+ padding="max_length",
527
+ max_length=tokenizer_two.model_max_length,
528
+ truncation=True,
529
+ return_tensors="pt",
530
+ ).input_ids[0]
531
+
532
+
533
+ def sdxl_text_conditioning(text_encoder_one, text_encoder_two, text_input_ids_one, text_input_ids_two):
534
+ prompt_embeds_1 = text_encoder_one(
535
+ text_input_ids_one,
536
+ output_hidden_states=True,
537
+ ).hidden_states[-2]
538
+
539
+ prompt_embeds_1 = prompt_embeds_1.view(prompt_embeds_1.shape[0], prompt_embeds_1.shape[1], -1)
540
+
541
+ prompt_embeds_2 = text_encoder_two(
542
+ text_input_ids_two,
543
+ output_hidden_states=True,
544
+ )
545
+
546
+ pooled_encoder_hidden_states = prompt_embeds_2[0]
547
+
548
+ prompt_embeds_2 = prompt_embeds_2.hidden_states[-2]
549
+
550
+ prompt_embeds_2 = prompt_embeds_2.view(prompt_embeds_2.shape[0], prompt_embeds_2.shape[1], -1)
551
+
552
+ encoder_hidden_states = torch.cat((prompt_embeds_1, prompt_embeds_2), dim=-1)
553
+
554
+ return encoder_hidden_states, pooled_encoder_hidden_states
555
+
556
+
557
+ def make_random_rectangle_mask(
558
+ height,
559
+ width,
560
+ margin=10,
561
+ bbox_min_size=100,
562
+ bbox_max_size=512,
563
+ min_times=1,
564
+ max_times=2,
565
+ ):
566
+ mask = np.zeros((height, width), np.float32)
567
+
568
+ bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
569
+
570
+ times = np.random.randint(min_times, max_times + 1)
571
+
572
+ for i in range(times):
573
+ box_width = np.random.randint(bbox_min_size, bbox_max_size)
574
+ box_height = np.random.randint(bbox_min_size, bbox_max_size)
575
+
576
+ start_x = np.random.randint(margin, width - margin - box_width + 1)
577
+ start_y = np.random.randint(margin, height - margin - box_height + 1)
578
+
579
+ mask[start_y : start_y + box_height, start_x : start_x + box_width] = 1
580
+
581
+ return mask
582
+
583
+
584
+ def make_random_irregular_mask(height, width, max_angle=4, max_len=60, max_width=256, min_times=1, max_times=2):
585
+ import cv2
586
+
587
+ mask = np.zeros((height, width), np.float32)
588
+
589
+ times = np.random.randint(min_times, max_times + 1)
590
+
591
+ for i in range(times):
592
+ start_x = np.random.randint(width)
593
+ start_y = np.random.randint(height)
594
+
595
+ for j in range(1 + np.random.randint(5)):
596
+ angle = 0.01 + np.random.randint(max_angle)
597
+
598
+ if i % 2 == 0:
599
+ angle = 2 * 3.1415926 - angle
600
+
601
+ length = 10 + np.random.randint(max_len)
602
+
603
+ brush_w = 5 + np.random.randint(max_width)
604
+
605
+ end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
606
+ end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
607
+
608
+ choice = random.randint(0, 2)
609
+
610
+ if choice == 0:
611
+ cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
612
+ elif choice == 1:
613
+ cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1.0, thickness=-1)
614
+ elif choice == 2:
615
+ radius = brush_w // 2
616
+ mask[
617
+ start_y - radius : start_y + radius,
618
+ start_x - radius : start_x + radius,
619
+ ] = 1
620
+ else:
621
+ assert False
622
+
623
+ start_x, start_y = end_x, end_y
624
+
625
+ return mask
626
+
627
+
628
+ def make_outpainting_mask(height, width, probs=[0.5, 0.5, 0.5, 0.5]):
629
+ mask = np.zeros((height, width), np.float32)
630
+ at_least_one_mask_applied = False
631
+
632
+ coords = [
633
+ [(0, 0), (1, get_padding(height))],
634
+ [(0, 0), (get_padding(width), 1)],
635
+ [(0, 1 - get_padding(height)), (1, 1)],
636
+ [(1 - get_padding(width), 0), (1, 1)],
637
+ ]
638
+
639
+ for pp, coord in zip(probs, coords):
640
+ if np.random.random() < pp:
641
+ at_least_one_mask_applied = True
642
+ mask = apply_padding(mask=mask, coord=coord)
643
+
644
+ if not at_least_one_mask_applied:
645
+ idx = np.random.choice(range(len(coords)), p=np.array(probs) / sum(probs))
646
+ mask = apply_padding(mask=mask, coord=coords[idx])
647
+
648
+ return mask
649
+
650
+
651
+ def get_padding(size, min_padding_percent=0.04, max_padding_percent=0.5):
652
+ n1 = int(min_padding_percent * size)
653
+ n2 = int(max_padding_percent * size)
654
+ return np.random.randint(n1, n2) / size
655
+
656
+
657
+ def apply_padding(mask, coord):
658
+ height, width = mask.shape
659
+
660
+ mask[
661
+ int(coord[0][0] * height) : int(coord[1][0] * height),
662
+ int(coord[0][1] * width) : int(coord[1][1] * width),
663
+ ] = 1
664
+
665
+ return mask
666
+
667
+
668
+ @torch.no_grad()
669
+ def sdxl_diffusion_loop(
670
+ prompts,
671
+ unet,
672
+ text_encoder_one,
673
+ text_encoder_two,
674
+ images=None,
675
+ controlnet=None,
676
+ adapter=None,
677
+ sigmas=None,
678
+ timesteps=None,
679
+ x_T=None,
680
+ micro_conditioning=None,
681
+ guidance_scale=5.0,
682
+ generator=None,
683
+ negative_prompts=None,
684
+ diffusion_loop=euler_ode_solver_diffusion_loop,
685
+ ):
686
+ if negative_prompts is None:
687
+ negative_prompts = [""] * len(prompts)
688
+
689
+ prompts += negative_prompts
690
+
691
+ encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(
692
+ text_encoder_one,
693
+ text_encoder_two,
694
+ sdxl_tokenize_one(prompts).to(text_encoder_one.device),
695
+ sdxl_tokenize_two(prompts).to(text_encoder_two.device),
696
+ )
697
+
698
+ if x_T is None:
699
+ x_T = torch.randn((1, 4, 1024 // 8, 1024 // 8), dtype=torch.float32, device=unet.device, generator=generator)
700
+ x_T = x_T * ((sigmas.max() ** 2 + 1) ** 0.5)
701
+
702
+ if sigmas is None:
703
+ sigmas = make_sigmas(device=unet.device)
704
+
705
+ if timesteps is None:
706
+ timesteps = torch.linspace(0, sigmas.numel(), 50, dtype=torch.long, device=unet.device)
707
+
708
+ if micro_conditioning is None:
709
+ micro_conditioning = torch.tensor([1024, 1024, 0, 0, 1024, 1024], dtype=torch.long, device=unet.device)
710
+
711
+ if adapter is not None:
712
+ down_block_additional_residuals = adapter(images)
713
+ else:
714
+ down_block_additional_residuals = None
715
+
716
+ if controlnet is not None:
717
+ controlnet_cond = images
718
+ else:
719
+ controlnet_cond = None
720
+
721
+ eps_theta = lambda x_t, t, sigma: sdxl_eps_theta(
722
+ x_t=x_t,
723
+ t=t,
724
+ sigma=sigma,
725
+ unet=unet,
726
+ encoder_hidden_states=encoder_hidden_states,
727
+ pooled_encoder_hidden_states=pooled_encoder_hidden_states,
728
+ micro_conditioning=micro_conditioning,
729
+ guidance_scale=guidance_scale,
730
+ controlnet=controlnet,
731
+ controlnet_cond=controlnet_cond,
732
+ down_block_additional_residuals=down_block_additional_residuals,
733
+ )
734
+
735
+ x_0 = diffusion_loop(eps_theta=eps_theta, timesteps=timesteps, sigmas=sigmas, x_T=x_T)
736
+
737
+ return x_0
738
+
739
+
740
+ @torch.no_grad()
741
+ def sdxl_eps_theta(
742
+ x_t,
743
+ t,
744
+ sigma,
745
+ unet,
746
+ encoder_hidden_states,
747
+ pooled_encoder_hidden_states,
748
+ micro_conditioning,
749
+ guidance_scale,
750
+ controlnet=None,
751
+ controlnet_cond=None,
752
+ down_block_additional_residuals=None,
753
+ ):
754
+ # TODO - how does this not affect the ode we are solving
755
+ scaled_x_t = x_t / ((sigma**2 + 1) ** 0.5)
756
+
757
+ if guidance_scale > 1.0:
758
+ scaled_x_t = torch.concat([scaled_x_t, scaled_x_t])
759
+
760
+ if controlnet is not None:
761
+ controlnet_out = controlnet(
762
+ x_t=scaled_x_t,
763
+ t=t,
764
+ encoder_hidden_states=encoder_hidden_states,
765
+ micro_conditioning=micro_conditioning,
766
+ pooled_encoder_hidden_states=pooled_encoder_hidden_states,
767
+ controlnet_cond=controlnet_cond,
768
+ )
769
+
770
+ down_block_additional_residuals = controlnet_out["down_block_res_samples"]
771
+ mid_block_additional_residual = controlnet_out["mid_block_res_sample"]
772
+ add_to_down_block_inputs = controlnet_out.get("add_to_down_block_inputs", None)
773
+ add_to_output = controlnet_out.get("add_to_output", None)
774
+ else:
775
+ mid_block_additional_residual = None
776
+ add_to_down_block_inputs = None
777
+ add_to_output = None
778
+
779
+ eps_hat = unet(
780
+ x_t=scaled_x_t,
781
+ t=t,
782
+ encoder_hidden_states=encoder_hidden_states,
783
+ micro_conditioning=micro_conditioning,
784
+ pooled_encoder_hidden_states=pooled_encoder_hidden_states,
785
+ down_block_additional_residuals=down_block_additional_residuals,
786
+ mid_block_additional_residual=mid_block_additional_residual,
787
+ add_to_down_block_inputs=add_to_down_block_inputs,
788
+ add_to_output=add_to_output,
789
+ )
790
+
791
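+ # classifier-free guidance: eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)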
+ if guidance_scale > 1.0:
792
+ eps_hat, eps_hat_uncond = eps_hat.chunk(2)
793
+
794
+ eps_hat = eps_hat_uncond + guidance_scale * (eps_hat - eps_hat_uncond)
795
+
796
+ return eps_hat
797
+
798
+ known_negative_prompt = "text, watermark, low-quality, signature, moiré pattern, downsampling, aliasing, distorted, blurry, glossy, blur, jpeg artifacts, compression artifacts, poorly drawn, low-resolution, bad, distortion, twisted, excessive, exaggerated pose, exaggerated limbs, grainy, symmetrical, duplicate, error, pattern, beginner, pixelated, fake, hyper, glitch, overexposed, high-contrast, bad-contrast"
799
+
800
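+ # Convenience wrapper around sdxl_diffusion_loop: loads the default fp16 text encoders, VAE and UNet
+ # when they are not passed in, optionally loads a controlnet or adapter checkpoint, preprocesses the
+ # input images/masks, and decodes the resulting latents back to PIL images.
+ #
+ # Example sketch (the checkpoint path is hypothetical):
+ #   images = gen_sdxl_simplified_interface(
+ #       prompt="a lighthouse at dusk", num_images=2,
+ #       controlnet="SDXLControlNet", controlnet_checkpoint="./weights/my_controlnet.safetensors",
+ #       images=["cond.png"], apply_conditioning="canny")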
+ def gen_sdxl_simplified_interface(
801
+ prompt:str,
802
+ negative_prompt: Optional[str] = None,
803
+ controlnet_checkpoint: Optional[str]=None,
804
+ controlnet: Optional[Literal["SDXLControlNet", "SDXLControlNetFull", "SDXLControlNetPreEncodedControlnetCond"]]=None,
805
+ adapter_checkpoint: Optional[str]=None,
806
+ num_inference_steps=50,
807
+ images=None,
808
+ masks=None,
809
+ apply_conditioning: Optional[Literal["canny"]]=None,
810
+ num_images: int=1,
811
+ device: Optional[str]=None,
812
+ text_encoder_one=None,
813
+ text_encoder_two=None,
814
+ unet=None,
815
+ vae=None,
816
+ ):
817
+ if device is None:
818
+ if torch.cuda.is_available():
819
+ device = "cuda"
820
+ elif torch.backends.mps.is_available():
821
+ device = "mps"
822
+
823
+ if text_encoder_one is None:
824
+ text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16)
825
+ text_encoder_one.to(device=device)
826
+
827
+ if text_encoder_two is None:
828
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16)
829
+ text_encoder_two.to(device=device)
830
+
831
+ if vae is None:
832
+ vae = SDXLVae.load_fp16_fix(device=device)
833
+
834
+ if unet is None:
835
+ unet = SDXLUNet.load_fp16(device=device)
836
+
837
+ if isinstance(controlnet, str) and controlnet_checkpoint is not None:
838
+ if controlnet == "SDXLControlNet":
839
+ controlnet = SDXLControlNet.load(controlnet_checkpoint, device=device, dtype=torch.float16)
840
+ elif controlnet == "SDXLControlNetFull":
841
+ controlnet = SDXLControlNetFull.load(controlnet_checkpoint, device=device, dtype=torch.float16)
842
+ elif controlnet == "SDXLControlNetPreEncodedControlnetCond":
843
+ controlnet = SDXLControlNetPreEncodedControlnetCond.load(controlnet_checkpoint, device=device, dtype=torch.float16)
844
+ else:
845
+ assert False
846
+
847
+ if adapter_checkpoint is not None:
848
+ adapter = SDXLAdapter.load(adapter_checkpoint, device=device, dtype=torch.float16)
849
+ else:
850
+ adapter = None
851
+
852
+ sigmas = make_sigmas(device=unet.device)
+
+ timesteps = torch.linspace(0, sigmas.numel() - 1, num_inference_steps, dtype=torch.long, device=unet.device)
855
+
856
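+ # Image/mask preprocessing: images are resized to 1024x1024 (optionally replaced by their canny edges);
+ # for the pre-encoded-cond controlnet the masked image is VAE-encoded and concatenated with the
+ # downsampled mask (4 + 1 latent channels), otherwise masked pixels are simply set to -1.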
+ if images is not None:
857
+ if not isinstance(images, list):
858
+ images = [images]
859
+
860
+ if masks is not None and not isinstance(masks, list):
861
+ masks = [masks]
862
+
863
+ images_ = []
864
+
865
+ for image_idx, image in enumerate(images):
866
+ if isinstance(image, str):
867
+ image = Image.open(image)
868
+ image = image.convert("RGB")
869
+ image = image.resize((1024, 1024))
870
+ elif isinstance(image, Image.Image):
871
+ ...
872
+ else:
873
+ assert False
874
+
875
+ if apply_conditioning == "canny":
876
+ import cv2
877
+
878
+ image = np.array(image)
879
+ image = cv2.Canny(image, 100, 200)
880
+ image = image[:, :, None]
881
+ image = np.concatenate([image, image, image], axis=2)
882
+
883
+ image = TF.to_tensor(image)
884
+
885
+ if masks is not None:
886
+ mask = masks[image_idx]
887
+ if isinstance(mask, str):
888
+ mask = Image.open(mask)
889
+ mask = mask.convert("L")
890
+ mask = mask.resize((1024, 1024))
891
+ elif isinstance(mask, Image.Image):
892
+ ...
893
+ else:
894
+ assert False
895
+ mask = TF.to_tensor(mask)
896
+
897
+ if isinstance(controlnet, SDXLControlNetPreEncodedControlnetCond):
+ image = image * (mask < 0.5)
+ image = TF.normalize(image, [0.5], [0.5])
+ image = vae.encode(image[None].to(device=vae.device, dtype=vae.dtype))
+ mask = TF.resize(mask, (1024 // 8, 1024 // 8))[None].to(device=image.device, dtype=image.dtype)
+ image = torch.concat((image, mask), dim=1)
903
+ else:
904
+ image = image * (mask < 0.5) + -1.0 * (mask >= 0.5)
905
+
906
+ images_.append(image)
907
+
908
+ images_ = torch.concat([i if i.ndim == 4 else i[None] for i in images_])
909
+ else:
910
+ images_ = None
911
+
912
+ x_0 = sdxl_diffusion_loop(
913
+ prompts=[prompt] * num_images,
914
+ negative_prompts=[negative_prompt] * num_images if negative_prompt is not None else None,
915
+ unet=unet,
916
+ text_encoder_one=text_encoder_one,
917
+ text_encoder_two=text_encoder_two,
918
+ sigmas=sigmas,
919
+ timesteps=timesteps,
920
+ controlnet=controlnet,
921
+ adapter=adapter,
922
+ images=images_,
923
+ )
924
+
925
+ x_0 = vae.decode(x_0)
926
+ x_0 = vae.output_tensor_to_pil(x_0)
927
+
928
+ return x_0
929
+
930
+
931
+ if __name__ == "__main__":
932
+ from argparse import ArgumentParser
933
+
934
+ args = ArgumentParser()
935
+ args.add_argument("--prompt", required=True, type=str)
936
+ args.add_argument("--num_images", required=True, type=int, default=1)
937
+ args.add_argument("--num_inference_steps", required=False, type=int, default=50)
938
+ args.add_argument("--image", required=False, type=str, default=None)
939
+ args.add_argument("--mask", required=False, type=str, default=None)
940
+ args.add_argument("--controlnet_checkpoint", required=False, type=str, default=None)
941
+ args.add_argument("--controlnet", required=False, choices=["SDXLControlNet", "SDXLControlNetFull", "SDXLControNetPreEncodedControlnetCond"], default=None)
942
+ args.add_argument("--adapter_checkpoint", required=False, type=str, default=None)
943
+ args.add_argument("--apply_conditioning", choices=["canny"], required=False, default=None)
944
+ args.add_argument("--device", required=False, default=None)
945
+ args = args.parse_args()
946
+
947
+ images = gen_sdxl_simplified_interface(
948
+ prompt=args.prompt,
949
+ num_images=args.num_images,
950
+ num_inference_steps=args.num_inference_steps,
951
+ images=[args.image] if args.image is not None else None,
+ masks=[args.mask] if args.mask is not None else None,
953
+ controlnet_checkpoint=args.controlnet_checkpoint,
954
+ controlnet=args.controlnet,
955
+ adapter_checkpoint=args.adapter_checkpoint,
956
+ apply_conditioning=args.apply_conditioning,
957
+ device=args.device,
958
+ negative_prompt=known_negative_prompt,
959
+ )
960
+
961
+ for i, image in enumerate(images):
962
+ image.save(f"out_{i}.png")
sdxl_models.py ADDED
@@ -0,0 +1,1375 @@
1
+ import math
2
+ import os
3
+ from typing import List, Optional
4
+
5
+ import safetensors.torch
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms.functional as TF
9
+ import xformers
10
+ from PIL import Image
11
+ from torch import nn
12
+
13
+
14
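+ # Mixin shared by the models below: exposes dtype/device helpers and a `load` classmethod that
+ # instantiates the module on the meta device and loads one or more safetensors files via
+ # load_state_dict(..., assign=True) (importing load_state_dict_patch applies the patch this relies on).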
+ class ModelUtils:
15
+ @property
16
+ def dtype(self):
17
+ return next(self.parameters()).dtype
18
+
19
+ @property
20
+ def device(self):
21
+ return next(self.parameters()).device
22
+
23
+ @classmethod
24
+ def load(cls, load_from: str, device, overrides: Optional[List[str]] = None):
25
+ import load_state_dict_patch
26
+
27
+ load_from = [load_from]
28
+
29
+ if overrides is not None:
+ load_from += overrides
30
+
31
+ state_dict = {}
32
+
33
+ for load_from_ in load_from:
34
+ if os.path.isdir(load_from_):
35
+ load_from_ = os.path.join(load_from_, "diffusion_pytorch_model.safetensors")
36
+
37
+ state_dict.update(safetensors.torch.load_file(load_from_, device=device))
38
+
39
+ with torch.device("meta"):
40
+ model = cls()
41
+
42
+ model.load_state_dict(state_dict, assign=True)
43
+
44
+ return model
45
+
46
+
47
+ vae_scaling_factor = 0.13025
48
+
49
+
50
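+ # SDXL autoencoder (same layout as diffusers' AutoencoderKL): encode() samples a 4-channel latent
+ # from the predicted mean/logvar and multiplies by vae_scaling_factor; decode() divides it back out
+ # and reconstructs the RGB image.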
+ class SDXLVae(nn.Module, ModelUtils):
51
+ def __init__(self):
52
+ super().__init__()
53
+
54
+ # fmt: off
55
+
56
+ self.encoder = nn.ModuleDict(dict(
57
+ # 3 -> 128
58
+ conv_in=nn.Conv2d(3, 128, kernel_size=3, padding=1),
59
+
60
+ down_blocks=nn.ModuleList([
61
+ # 128 -> 128
62
+ nn.ModuleDict(dict(
63
+ resnets=nn.ModuleList([ResnetBlock2D(128, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6)]),
64
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)))]),
65
+ )),
66
+ # 128 -> 256
67
+ nn.ModuleDict(dict(
68
+ resnets=nn.ModuleList([ResnetBlock2D(128, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6)]),
69
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)))]),
70
+ )),
71
+ # 256 -> 512
72
+ nn.ModuleDict(dict(
73
+ resnets=nn.ModuleList([ResnetBlock2D(256, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
74
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)))]),
75
+ )),
76
+ # 512 -> 512
77
+ nn.ModuleDict(dict(resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]))),
78
+ ]),
79
+
80
+ # 512 -> 512
81
+ mid_block=nn.ModuleDict(dict(
82
+ attentions=nn.ModuleList([Attention(512, 512, qkv_bias=True)]),
83
+ resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
84
+ )),
85
+
86
+ # 512 -> 8
87
+ conv_norm_out=nn.GroupNorm(32, 512, eps=1e-06),
88
+ conv_act=nn.SiLU(),
89
+ conv_out=nn.Conv2d(512, 8, kernel_size=3, padding=1)
90
+ ))
91
+
92
+ # 8 -> 8
93
+ self.quant_conv = nn.Conv2d(8, 8, kernel_size=1)
94
+
95
+ # 8 -> 4 from sampling mean and std
96
+
97
+ # 4 -> 4
98
+ self.post_quant_conv = nn.Conv2d(4, 4, 1)
99
+
100
+ self.decoder = nn.ModuleDict(dict(
101
+ # 4 -> 512
102
+ conv_in=nn.Conv2d(4, 512, kernel_size=3, padding=1),
103
+
104
+ # 512 -> 512
105
+ mid_block=nn.ModuleDict(dict(
106
+ attentions=nn.ModuleList([Attention(512, 512, qkv_bias=True)]),
107
+ resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
108
+ )),
109
+
110
+ up_blocks=nn.ModuleList([
111
+ # 512 -> 512
112
+ nn.ModuleDict(dict(
113
+ resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
114
+ upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)))]),
115
+ )),
116
+
117
+ # 512 -> 512
118
+ nn.ModuleDict(dict(
119
+ resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
120
+ upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)))]),
121
+ )),
122
+
123
+ # 512 -> 256
124
+ nn.ModuleDict(dict(
125
+ resnets=nn.ModuleList([ResnetBlock2D(512, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6)]),
126
+ upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)))]),
127
+ )),
128
+
129
+ # 256 -> 128
130
+ nn.ModuleDict(dict(
131
+ resnets=nn.ModuleList([ResnetBlock2D(256, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6)]),
132
+ )),
133
+ ]),
134
+
135
+ # 128 -> 3
136
+ conv_norm_out=nn.GroupNorm(32, 128, eps=1e-06),
137
+ conv_act=nn.SiLU(),
138
+ conv_out=nn.Conv2d(128, 3, kernel_size=3, padding=1)
139
+ ))
140
+
141
+ # fmt: on
142
+
143
+ def encode(self, x, generator=None):
144
+ h = x
145
+
146
+ h = self.encoder["conv_in"](h)
147
+
148
+ for down_block in self.encoder["down_blocks"]:
149
+ for resnet in down_block["resnets"]:
150
+ h = resnet(h)
151
+
152
+ if "downsamplers" in down_block:
153
+ h = down_block["downsamplers"][0]["conv"](h)
154
+
155
+ h = self.encoder["mid_block"]["resnets"][0](h)
156
+ h = self.encoder["mid_block"]["attentions"][0](h)
157
+ h = self.encoder["mid_block"]["resnets"][1](h)
158
+
159
+ h = self.encoder["conv_norm_out"](h)
160
+ h = self.encoder["conv_act"](h)
161
+ h = self.encoder["conv_out"](h)
162
+
163
+ mean, logvar = self.quant_conv(h).chunk(2, dim=1)
164
+
165
+ logvar = torch.clamp(logvar, -30.0, 20.0)
166
+
167
+ std = torch.exp(0.5 * logvar)
168
+
169
+ z = mean + torch.randn(mean.shape, device=mean.device, dtype=mean.dtype, generator=generator) * std
170
+
171
+ z = z * vae_scaling_factor
172
+
173
+ return z
174
+
175
+ def decode(self, z):
176
+ z = z / vae_scaling_factor
177
+
178
+ h = z
179
+
180
+ h = self.post_quant_conv(h)
+
+ h = self.decoder["conv_in"](h)
181
+
182
+ h = self.decoder["mid_block"]["resnets"][0](h)
183
+ h = self.decoder["mid_block"]["attentions"][0](h)
184
+ h = self.decoder["mid_block"]["resnets"][1](h)
185
+
186
+ for up_block in self.decoder["up_blocks"]:
187
+ for resnet in up_block["resnets"]:
188
+ h = resnet(h)
189
+
190
+ if "upsamplers" in up_block:
191
+ h = up_block["upsamplers"][0]["conv"](h)
192
+
193
+ h = self.decoder["conv_norm_out"](h)
194
+ h = self.decoder["conv_act"](h)
195
+ h = self.decoder["conv_out"](h)
196
+
197
+ x_pred = h
198
+
199
+ return x_pred
200
+
201
+ @classmethod
202
+ def input_pil_to_tensor(cls, x):
203
+ x = TF.to_tensor(x)
204
+ x = TF.normalize(x, [0.5], [0.5])
205
+ if x.ndim == 3:
206
+ x = x[None, :, :, :]
207
+ return x
208
+
209
+ @classmethod
210
+ def output_tensor_to_pil(cls, x_pred):
+
+ x_pred = ((x_pred * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1)
+
+ x_pred = x_pred.cpu().numpy()
214
+
215
+ x_pred = [Image.fromarray(x) for x in x_pred]
216
+
217
+ return x_pred
218
+
219
+ @classmethod
220
+ def load_fp32(cls, device=None, overrides=None):
221
+ return cls.load("./weights/sdxl_vae.safetensors", device=device, overrides=overrides)
222
+
223
+ @classmethod
224
+ def load_fp16(cls, device=None, overrides=None):
225
+ return cls.load("./weights/sdxl_vae.fp16.safetensors", device=device, overrides=overrides)
226
+
227
+ @classmethod
228
+ def load_fp16_fix(cls, device=None, overrides=None):
229
+ return cls.load("./weights/sdxl_vae_fp16_fix.safetensors", device=device, overrides=overrides)
230
+
231
+
232
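+ # The SDXL denoiser. Conditioning: sinusoidal timestep embedding, pooled + per-token CLIP text
+ # embeddings, and six micro-conditioning values (original size, crop top-left, target size).
+ # The forward pass can additionally mix in controlnet residuals (down/mid) and T2I-adapter features.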
+ class SDXLUNet(nn.Module, ModelUtils):
233
+ def __init__(self):
234
+ super().__init__()
235
+
236
+ # fmt: off
237
+
238
+ encoder_hidden_states_dim = 2048
239
+
240
+ # timesteps embedding:
241
+
242
+ time_sinusoidal_embedding_dim = 320
243
+ time_embedding_dim = 1280
244
+
245
+ self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)
246
+
247
+ self.time_embedding = nn.ModuleDict(dict(
248
+ linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
249
+ act=nn.SiLU(),
250
+ linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
251
+ ))
252
+
253
+ # image size and crop coordinates conditioning embedding (i.e. micro conditioning):
254
+
255
+ num_micro_conditioning_values = 6
256
+ micro_conditioning_embedding_dim = 256
257
+ additional_embedding_encoder_dim = 1280
258
+ self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)
259
+
260
+ self.add_embedding = nn.ModuleDict(dict(
261
+ linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
262
+ act=nn.SiLU(),
263
+ linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
264
+ ))
265
+
266
+ # actual unet blocks:
267
+
268
+ self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
269
+
270
+ self.down_blocks = nn.ModuleList([
271
+ # 320 -> 320
272
+ nn.ModuleDict(dict(
273
+ resnets=nn.ModuleList([
274
+ ResnetBlock2D(320, 320, time_embedding_dim),
275
+ ResnetBlock2D(320, 320, time_embedding_dim),
276
+ ]),
277
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
278
+ )),
279
+ # 320 -> 640
280
+ nn.ModuleDict(dict(
281
+ resnets=nn.ModuleList([
282
+ ResnetBlock2D(320, 640, time_embedding_dim),
283
+ ResnetBlock2D(640, 640, time_embedding_dim),
284
+ ]),
285
+ attentions=nn.ModuleList([
286
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
287
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
288
+ ]),
289
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
290
+ )),
291
+ # 640 -> 1280
292
+ nn.ModuleDict(dict(
293
+ resnets=nn.ModuleList([
294
+ ResnetBlock2D(640, 1280, time_embedding_dim),
295
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
296
+ ]),
297
+ attentions=nn.ModuleList([
298
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
299
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
300
+ ]),
301
+ )),
302
+ ])
303
+
304
+ self.mid_block = nn.ModuleDict(dict(
305
+ resnets=nn.ModuleList([
306
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
307
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
308
+ ]),
309
+ attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
310
+ ))
311
+
312
+ self.up_blocks = nn.ModuleList([
313
+ # 1280 -> 1280
314
+ nn.ModuleDict(dict(
315
+ resnets=nn.ModuleList([
316
+ ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
317
+ ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
318
+ ResnetBlock2D(1280 + 640, 1280, time_embedding_dim),
319
+ ]),
320
+ attentions=nn.ModuleList([
321
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
322
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
323
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
324
+ ]),
325
+ upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(1280, 1280, kernel_size=3, padding=1)))]),
326
+ )),
327
+ # 1280 -> 640
328
+ nn.ModuleDict(dict(
329
+ resnets=nn.ModuleList([
330
+ ResnetBlock2D(1280 + 640, 640, time_embedding_dim),
331
+ ResnetBlock2D(640 + 640, 640, time_embedding_dim),
332
+ ResnetBlock2D(640 + 320, 640, time_embedding_dim),
333
+ ]),
334
+ attentions=nn.ModuleList([
335
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
336
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
337
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
338
+ ]),
339
+ upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, padding=1)))]),
340
+ )),
341
+ # 640 -> 320
342
+ nn.ModuleDict(dict(
343
+ resnets=nn.ModuleList([
344
+ ResnetBlock2D(640 + 320, 320, time_embedding_dim),
345
+ ResnetBlock2D(320 + 320, 320, time_embedding_dim),
346
+ ResnetBlock2D(320 + 320, 320, time_embedding_dim),
347
+ ]),
348
+ ))
349
+ ])
350
+
351
+ self.conv_norm_out = nn.GroupNorm(32, 320)
352
+ self.conv_act = nn.SiLU()
353
+ self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1)
354
+
355
+ # fmt: on
356
+
357
+ def forward(
358
+ self,
359
+ x_t,
360
+ t,
361
+ encoder_hidden_states,
362
+ micro_conditioning,
363
+ pooled_encoder_hidden_states,
364
+ down_block_additional_residuals: Optional[List[torch.Tensor]] = None,
365
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
366
+ add_to_down_block_inputs: Optional[List[torch.Tensor]] = None,
367
+ add_to_output: Optional[torch.Tensor] = None,
368
+ ):
369
+ hidden_state = x_t
370
+
371
+ t = self.get_sinusoidal_timestep_embedding(t)
372
+ t = t.to(dtype=hidden_state.dtype)
373
+ t = self.time_embedding["linear_1"](t)
374
+ t = self.time_embedding["act"](t)
375
+ t = self.time_embedding["linear_2"](t)
376
+
377
+ additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
378
+ additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
379
+ additional_conditioning = additional_conditioning.flatten(1)
380
+ additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
381
+ additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
382
+ additional_conditioning = self.add_embedding["act"](additional_conditioning)
383
+ additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)
384
+
385
+ t = t + additional_conditioning
386
+
387
+ hidden_state = self.conv_in(hidden_state)
388
+
389
+ residuals = [hidden_state]
390
+
391
+ for down_block in self.down_blocks:
392
+ for i, resnet in enumerate(down_block["resnets"]):
393
+ if add_to_down_block_inputs is not None:
394
+ hidden_state = hidden_state + add_to_down_block_inputs.pop(0)
395
+
396
+ hidden_state = resnet(hidden_state, t)
397
+
398
+ if "attentions" in down_block:
399
+ hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)
400
+
401
+ residuals.append(hidden_state)
402
+
403
+ if "downsamplers" in down_block:
404
+ if add_to_down_block_inputs is not None:
405
+ hidden_state = hidden_state + add_to_down_block_inputs.pop(0)
406
+
407
+ hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)
408
+
409
+ residuals.append(hidden_state)
410
+
411
+ hidden_state = self.mid_block["resnets"][0](hidden_state, t)
412
+ hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
413
+ hidden_state = self.mid_block["resnets"][1](hidden_state, t)
414
+
415
+ if mid_block_additional_residual is not None:
416
+ hidden_state = hidden_state + mid_block_additional_residual
417
+
418
+ for up_block in self.up_blocks:
419
+ for i, resnet in enumerate(up_block["resnets"]):
420
+ residual = residuals.pop()
421
+
422
+ if down_block_additional_residuals is not None:
423
+ residual = residual + down_block_additional_residuals.pop()
424
+
425
+ hidden_state = torch.concat([hidden_state, residual], dim=1)
426
+
427
+ hidden_state = resnet(hidden_state, t)
428
+
429
+ if "attentions" in up_block:
430
+ hidden_state = up_block["attentions"][i](hidden_state, encoder_hidden_states)
431
+
432
+ if "upsamplers" in up_block:
433
+ hidden_state = F.interpolate(hidden_state, scale_factor=2.0, mode="nearest")
434
+ hidden_state = up_block["upsamplers"][0]["conv"](hidden_state)
435
+
436
+ hidden_state = self.conv_norm_out(hidden_state)
437
+ hidden_state = self.conv_act(hidden_state)
438
+ hidden_state = self.conv_out(hidden_state)
439
+
440
+ if add_to_output is not None:
441
+ hidden_state = hidden_state + add_to_output
442
+
443
+ eps_hat = hidden_state
444
+
445
+ return eps_hat
446
+
447
+ @classmethod
448
+ def load_fp32(cls, device=None, overrides=None):
449
+ return cls.load("./weights/sdxl_unet.safetensors", device=device, overrides=overrides)
450
+
451
+ @classmethod
452
+ def load_fp16(cls, device=None, overrides=None):
453
+ return cls.load("./weights/sdxl_unet.fp16.safetensors", device=device, overrides=overrides)
454
+
455
+
456
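+ # Standard ControlNet: a trainable copy of the UNet's conv_in/down/mid blocks plus a small conv
+ # encoder for the conditioning image and zero-initialized 1x1 projections. Returns residuals that
+ # the UNet adds to its skip connections and mid block output.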
+ class SDXLControlNet(nn.Module, ModelUtils):
457
+ def __init__(self):
458
+ super().__init__()
459
+
460
+ # fmt: off
461
+
462
+ encoder_hidden_states_dim = 2048
463
+
464
+ # timesteps embedding:
465
+
466
+ time_sinusoidal_embedding_dim = 320
467
+ time_embedding_dim = 1280
468
+
469
+ self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)
470
+
471
+ self.time_embedding = nn.ModuleDict(dict(
472
+ linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
473
+ act=nn.SiLU(),
474
+ linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
475
+ ))
476
+
477
+ # image size and crop coordinates conditioning embedding (i.e. micro conditioning):
478
+
479
+ num_micro_conditioning_values = 6
480
+ micro_conditioning_embedding_dim = 256
481
+ additional_embedding_encoder_dim = 1280
482
+ self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)
483
+
484
+ self.add_embedding = nn.ModuleDict(dict(
485
+ linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
486
+ act=nn.SiLU(),
487
+ linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
488
+ ))
489
+
490
+ # controlnet cond embedding:
491
+ self.controlnet_cond_embedding = nn.ModuleDict(dict(
492
+ conv_in=nn.Conv2d(3, 16, kernel_size=3, padding=1),
493
+ blocks=nn.ModuleList([
494
+ # 16 -> 32
495
+ nn.Conv2d(16, 16, kernel_size=3, padding=1),
496
+ nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2),
497
+ # 32 -> 96
498
+ nn.Conv2d(32, 32, kernel_size=3, padding=1),
499
+ nn.Conv2d(32, 96, kernel_size=3, padding=1, stride=2),
500
+ # 96 -> 256
501
+ nn.Conv2d(96, 96, kernel_size=3, padding=1),
502
+ nn.Conv2d(96, 256, kernel_size=3, padding=1, stride=2),
503
+ ]),
504
+ conv_out=zero_module(nn.Conv2d(256, 320, kernel_size=3, padding=1)),
505
+ ))
506
+
507
+ # actual unet blocks:
508
+
509
+ self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
510
+
511
+ self.down_blocks = nn.ModuleList([
512
+ # 320 -> 320
513
+ nn.ModuleDict(dict(
514
+ resnets=nn.ModuleList([
515
+ ResnetBlock2D(320, 320, time_embedding_dim),
516
+ ResnetBlock2D(320, 320, time_embedding_dim),
517
+ ]),
518
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
519
+ )),
520
+ # 320 -> 640
521
+ nn.ModuleDict(dict(
522
+ resnets=nn.ModuleList([
523
+ ResnetBlock2D(320, 640, time_embedding_dim),
524
+ ResnetBlock2D(640, 640, time_embedding_dim),
525
+ ]),
526
+ attentions=nn.ModuleList([
527
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
528
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
529
+ ]),
530
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
531
+ )),
532
+ # 640 -> 1280
533
+ nn.ModuleDict(dict(
534
+ resnets=nn.ModuleList([
535
+ ResnetBlock2D(640, 1280, time_embedding_dim),
536
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
537
+ ]),
538
+ attentions=nn.ModuleList([
539
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
540
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
541
+ ]),
542
+ )),
543
+ ])
544
+
545
+ self.controlnet_down_blocks = nn.ModuleList([
546
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
547
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
548
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
549
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
550
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
551
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
552
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
553
+ zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
554
+ zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
555
+ ])
556
+
557
+ self.mid_block = nn.ModuleDict(dict(
558
+ resnets=nn.ModuleList([
559
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
560
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
561
+ ]),
562
+ attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
563
+ ))
564
+
565
+ self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1))
566
+
567
+ # fmt: on
568
+
569
+ def forward(
570
+ self,
571
+ x_t,
572
+ t,
573
+ encoder_hidden_states,
574
+ micro_conditioning,
575
+ pooled_encoder_hidden_states,
576
+ controlnet_cond,
577
+ ):
578
+ hidden_state = x_t
579
+
580
+ t = self.get_sinusoidal_timestep_embedding(t)
581
+ t = t.to(dtype=hidden_state.dtype)
582
+ t = self.time_embedding["linear_1"](t)
583
+ t = self.time_embedding["act"](t)
584
+ t = self.time_embedding["linear_2"](t)
585
+
586
+ additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
587
+ additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
588
+ additional_conditioning = additional_conditioning.flatten(1)
589
+ additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
590
+ additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
591
+ additional_conditioning = self.add_embedding["act"](additional_conditioning)
592
+ additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)
593
+
594
+ t = t + additional_conditioning
595
+
596
+ controlnet_cond = self.controlnet_cond_embedding["conv_in"](controlnet_cond)
597
+ controlnet_cond = F.silu(controlnet_cond)
598
+
599
+ for block in self.controlnet_cond_embedding["blocks"]:
600
+ controlnet_cond = F.silu(block(controlnet_cond))
601
+
602
+ controlnet_cond = self.controlnet_cond_embedding["conv_out"](controlnet_cond)
603
+
604
+ hidden_state = self.conv_in(hidden_state)
605
+
606
+ hidden_state = hidden_state + controlnet_cond
607
+
608
+ down_block_res_sample = self.controlnet_down_blocks[0](hidden_state)
609
+ down_block_res_samples = [down_block_res_sample]
610
+
611
+ for down_block in self.down_blocks:
612
+ for i, resnet in enumerate(down_block["resnets"]):
613
+ hidden_state = resnet(hidden_state, t)
614
+
615
+ if "attentions" in down_block:
616
+ hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)
617
+
618
+ down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
619
+ down_block_res_samples.append(down_block_res_sample)
620
+
621
+ if "downsamplers" in down_block:
622
+ hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)
623
+
624
+ down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
625
+ down_block_res_samples.append(down_block_res_sample)
626
+
627
+ hidden_state = self.mid_block["resnets"][0](hidden_state, t)
628
+ hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
629
+ hidden_state = self.mid_block["resnets"][1](hidden_state, t)
630
+
631
+ mid_block_res_sample = self.controlnet_mid_block(hidden_state)
632
+
633
+ return dict(
634
+ down_block_res_samples=down_block_res_samples,
635
+ mid_block_res_sample=mid_block_res_sample,
636
+ )
637
+
638
+ @classmethod
639
+ def from_unet(cls, unet):
640
+ controlnet = cls()
641
+
642
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
643
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
644
+
645
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
646
+
647
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
648
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
649
+
650
+ return controlnet
651
+
652
+
653
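+ # ControlNet variant whose conditioning image is already VAE-encoded: conv_in takes 9 channels
+ # (4 noisy latents + 4 conditioning latents + 1 downsampled mask), matching the preprocessing done
+ # in gen_sdxl_simplified_interface above.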
+ class SDXLControlNetPreEncodedControlnetCond(nn.Module, ModelUtils):
654
+ def __init__(self):
655
+ super().__init__()
656
+
657
+ # fmt: off
658
+
659
+ encoder_hidden_states_dim = 2048
660
+
661
+ # timesteps embedding:
662
+
663
+ time_sinusoidal_embedding_dim = 320
664
+ time_embedding_dim = 1280
665
+
666
+ self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)
667
+
668
+ self.time_embedding = nn.ModuleDict(dict(
669
+ linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
670
+ act=nn.SiLU(),
671
+ linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
672
+ ))
673
+
674
+ # image size and crop coordinates conditioning embedding (i.e. micro conditioning):
675
+
676
+ num_micro_conditioning_values = 6
677
+ micro_conditioning_embedding_dim = 256
678
+ additional_embedding_encoder_dim = 1280
679
+ self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)
680
+
681
+ self.add_embedding = nn.ModuleDict(dict(
682
+ linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
683
+ act=nn.SiLU(),
684
+ linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
685
+ ))
686
+
687
+ # actual unet blocks:
688
+
689
+ # unet latents: 4 +
690
+ # control image latents: 4 +
691
+ # controlnet_mask: 1
692
+ # = 9 channels
693
+ self.conv_in = nn.Conv2d(9, 320, kernel_size=3, padding=1)
694
+
695
+ self.down_blocks = nn.ModuleList([
696
+ # 320 -> 320
697
+ nn.ModuleDict(dict(
698
+ resnets=nn.ModuleList([
699
+ ResnetBlock2D(320, 320, time_embedding_dim),
700
+ ResnetBlock2D(320, 320, time_embedding_dim),
701
+ ]),
702
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
703
+ )),
704
+ # 320 -> 640
705
+ nn.ModuleDict(dict(
706
+ resnets=nn.ModuleList([
707
+ ResnetBlock2D(320, 640, time_embedding_dim),
708
+ ResnetBlock2D(640, 640, time_embedding_dim),
709
+ ]),
710
+ attentions=nn.ModuleList([
711
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
712
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
713
+ ]),
714
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
715
+ )),
716
+ # 640 -> 1280
717
+ nn.ModuleDict(dict(
718
+ resnets=nn.ModuleList([
719
+ ResnetBlock2D(640, 1280, time_embedding_dim),
720
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
721
+ ]),
722
+ attentions=nn.ModuleList([
723
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
724
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
725
+ ]),
726
+ )),
727
+ ])
728
+
729
+ self.controlnet_down_blocks = nn.ModuleList([
730
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
731
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
732
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
733
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
734
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
735
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
736
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
737
+ zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
738
+ zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
739
+ ])
740
+
741
+ self.mid_block = nn.ModuleDict(dict(
742
+ resnets=nn.ModuleList([
743
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
744
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
745
+ ]),
746
+ attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
747
+ ))
748
+
749
+ self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1))
750
+
751
+ # fmt: on
752
+
753
+ def forward(
754
+ self,
755
+ x_t,
756
+ t,
757
+ encoder_hidden_states,
758
+ micro_conditioning,
759
+ pooled_encoder_hidden_states,
760
+ controlnet_cond,
761
+ ):
762
+ hidden_state = x_t
763
+
764
+ t = self.get_sinusoidal_timestep_embedding(t)
765
+ t = t.to(dtype=hidden_state.dtype)
766
+ t = self.time_embedding["linear_1"](t)
767
+ t = self.time_embedding["act"](t)
768
+ t = self.time_embedding["linear_2"](t)
769
+
770
+ additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
771
+ additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
772
+ additional_conditioning = additional_conditioning.flatten(1)
773
+ additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
774
+ additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
775
+ additional_conditioning = self.add_embedding["act"](additional_conditioning)
776
+ additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)
777
+
778
+ t = t + additional_conditioning
779
+
780
+ hidden_state = torch.concat((hidden_state, controlnet_cond), dim=1)
781
+
782
+ hidden_state = self.conv_in(hidden_state)
783
+
784
+ down_block_res_sample = self.controlnet_down_blocks[0](hidden_state)
785
+ down_block_res_samples = [down_block_res_sample]
786
+
787
+ for down_block in self.down_blocks:
788
+ for i, resnet in enumerate(down_block["resnets"]):
789
+ hidden_state = resnet(hidden_state, t)
790
+
791
+ if "attentions" in down_block:
792
+ hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)
793
+
794
+ down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
795
+ down_block_res_samples.append(down_block_res_sample)
796
+
797
+ if "downsamplers" in down_block:
798
+ hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)
799
+
800
+ down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
801
+ down_block_res_samples.append(down_block_res_sample)
802
+
803
+ hidden_state = self.mid_block["resnets"][0](hidden_state, t)
804
+ hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
805
+ hidden_state = self.mid_block["resnets"][1](hidden_state, t)
806
+
807
+ mid_block_res_sample = self.controlnet_mid_block(hidden_state)
808
+
809
+ return dict(
810
+ down_block_res_samples=down_block_res_samples,
811
+ mid_block_res_sample=mid_block_res_sample,
812
+ )
813
+
814
+ @classmethod
815
+ def from_unet(cls, unet):
816
+ controlnet = cls()
817
+
818
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
819
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
820
+
821
+ conv_in_weight = unet.conv_in.state_dict()["weight"]
822
+ padding = torch.zeros((320, 5, 3, 3), device=conv_in_weight.device, dtype=conv_in_weight.dtype)
823
+ conv_in_weight = torch.concat((conv_in_weight, padding), dim=1)
824
+
825
+ conv_in_bias = unet.conv_in.state_dict()["bias"]
826
+
827
+ controlnet.conv_in.load_state_dict({"weight": conv_in_weight, "bias": conv_in_bias})
828
+
829
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
830
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
831
+
832
+ return controlnet
833
+
834
+
835
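+ # "Full" ControlNet: also mirrors the UNet's up blocks and output conv. Besides the usual
+ # down/mid residuals it returns add_to_down_block_inputs and add_to_output, so the control signal
+ # is injected at each down block input and onto the final eps prediction as well.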
+ class SDXLControlNetFull(nn.Module, ModelUtils):
836
+ def __init__(self):
837
+ super().__init__()
838
+
839
+ # fmt: off
840
+
841
+ encoder_hidden_states_dim = 2048
842
+
843
+ # timesteps embedding:
844
+
845
+ time_sinusoidal_embedding_dim = 320
846
+ time_embedding_dim = 1280
847
+
848
+ self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)
849
+
850
+ self.time_embedding = nn.ModuleDict(dict(
851
+ linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
852
+ act=nn.SiLU(),
853
+ linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
854
+ ))
855
+
856
+ # image size and crop coordinates conditioning embedding (i.e. micro conditioning):
857
+
858
+ num_micro_conditioning_values = 6
859
+ micro_conditioning_embedding_dim = 256
860
+ additional_embedding_encoder_dim = 1280
861
+ self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)
862
+
863
+ self.add_embedding = nn.ModuleDict(dict(
864
+ linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
865
+ act=nn.SiLU(),
866
+ linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
867
+ ))
868
+
869
+ # controlnet cond embedding:
870
+ self.controlnet_cond_embedding = nn.ModuleDict(dict(
871
+ conv_in=nn.Conv2d(3, 16, kernel_size=3, padding=1),
872
+ blocks=nn.ModuleList([
873
+ # 16 -> 32
874
+ nn.Conv2d(16, 16, kernel_size=3, padding=1),
875
+ nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2),
876
+ # 32 -> 96
877
+ nn.Conv2d(32, 32, kernel_size=3, padding=1),
878
+ nn.Conv2d(32, 96, kernel_size=3, padding=1, stride=2),
879
+ # 96 -> 256
880
+ nn.Conv2d(96, 96, kernel_size=3, padding=1),
881
+ nn.Conv2d(96, 256, kernel_size=3, padding=1, stride=2),
882
+ ]),
883
+ conv_out=zero_module(nn.Conv2d(256, 320, kernel_size=3, padding=1)),
884
+ ))
885
+
886
+ # actual unet blocks:
887
+
888
+ self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
889
+
890
+ self.down_blocks = nn.ModuleList([
891
+ # 320 -> 320
892
+ nn.ModuleDict(dict(
893
+ resnets=nn.ModuleList([
894
+ ResnetBlock2D(320, 320, time_embedding_dim),
895
+ ResnetBlock2D(320, 320, time_embedding_dim),
896
+ ]),
897
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
898
+ )),
899
+ # 320 -> 640
900
+ nn.ModuleDict(dict(
901
+ resnets=nn.ModuleList([
902
+ ResnetBlock2D(320, 640, time_embedding_dim),
903
+ ResnetBlock2D(640, 640, time_embedding_dim),
904
+ ]),
905
+ attentions=nn.ModuleList([
906
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
907
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
908
+ ]),
909
+ downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
910
+ )),
911
+ # 640 -> 1280
912
+ nn.ModuleDict(dict(
913
+ resnets=nn.ModuleList([
914
+ ResnetBlock2D(640, 1280, time_embedding_dim),
915
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
916
+ ]),
917
+ attentions=nn.ModuleList([
918
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
919
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
920
+ ]),
921
+ )),
922
+ ])
923
+
924
+ self.controlnet_down_blocks = nn.ModuleList([
925
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
926
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
927
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
928
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
929
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
930
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
931
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
932
+ zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
933
+ ])
934
+
935
+ self.mid_block = nn.ModuleDict(dict(
936
+ resnets=nn.ModuleList([
937
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
938
+ ResnetBlock2D(1280, 1280, time_embedding_dim),
939
+ ]),
940
+ attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
941
+ ))
942
+
943
+ self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1))
944
+
945
+ self.up_blocks = nn.ModuleList([
946
+ # 1280 -> 1280
947
+ nn.ModuleDict(dict(
948
+ resnets=nn.ModuleList([
949
+ ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
950
+ ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
951
+ ResnetBlock2D(1280 + 640, 1280, time_embedding_dim),
952
+ ]),
953
+ attentions=nn.ModuleList([
954
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
955
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
956
+ TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
957
+ ]),
958
+ upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(1280, 1280, kernel_size=3, padding=1)))]),
959
+ )),
960
+ # 1280 -> 640
961
+ nn.ModuleDict(dict(
962
+ resnets=nn.ModuleList([
963
+ ResnetBlock2D(1280 + 640, 640, time_embedding_dim),
964
+ ResnetBlock2D(640 + 640, 640, time_embedding_dim),
965
+ ResnetBlock2D(640 + 320, 640, time_embedding_dim),
966
+ ]),
967
+ attentions=nn.ModuleList([
968
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
969
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
970
+ TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
971
+ ]),
972
+ upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, padding=1)))]),
973
+ )),
974
+ # 640 -> 320
975
+ nn.ModuleDict(dict(
976
+ resnets=nn.ModuleList([
977
+ ResnetBlock2D(640 + 320, 320, time_embedding_dim),
978
+ ResnetBlock2D(320 + 320, 320, time_embedding_dim),
979
+ ResnetBlock2D(320 + 320, 320, time_embedding_dim),
980
+ ]),
981
+ ))
982
+ ])
983
+
984
+ # take the output of transformer(resnet(hidden_states)) and project it to
985
+ # the number of residual channels for the same block
986
+ self.controlnet_up_blocks = nn.ModuleList([
987
+ zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
988
+ zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
989
+ zero_module(nn.Conv2d(1280, 640, kernel_size=1)),
990
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
991
+ zero_module(nn.Conv2d(640, 640, kernel_size=1)),
992
+ zero_module(nn.Conv2d(640, 320, kernel_size=1)),
993
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
994
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
995
+ zero_module(nn.Conv2d(320, 320, kernel_size=1)),
996
+ ])
997
+
998
+ self.conv_norm_out = nn.GroupNorm(32, 320)
999
+ self.conv_act = nn.SiLU()
1000
+ self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1)
1001
+
1002
+ self.controlnet_conv_out = zero_module(nn.Conv2d(4, 4, kernel_size=1))
1003
+
1004
+ # fmt: on
1005
+
1006
+ def forward(
1007
+ self,
1008
+ x_t,
1009
+ t,
1010
+ encoder_hidden_states,
1011
+ micro_conditioning,
1012
+ pooled_encoder_hidden_states,
1013
+ controlnet_cond,
1014
+ ):
1015
+ hidden_state = x_t
1016
+
1017
+ t = self.get_sinusoidal_timestep_embedding(t)
1018
+ t = t.to(dtype=hidden_state.dtype)
1019
+ t = self.time_embedding["linear_1"](t)
1020
+ t = self.time_embedding["act"](t)
1021
+ t = self.time_embedding["linear_2"](t)
1022
+
1023
+ additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
1024
+ additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
1025
+ additional_conditioning = additional_conditioning.flatten(1)
1026
+ additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
1027
+ additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
1028
+ additional_conditioning = self.add_embedding["act"](additional_conditioning)
1029
+ additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)
1030
+
1031
+ t = t + additional_conditioning
1032
+
1033
+ controlnet_cond = self.controlnet_cond_embedding["conv_in"](controlnet_cond)
1034
+ controlnet_cond = F.silu(controlnet_cond)
1035
+
1036
+ for block in self.controlnet_cond_embedding["blocks"]:
1037
+ controlnet_cond = F.silu(block(controlnet_cond))
1038
+
1039
+ controlnet_cond = self.controlnet_cond_embedding["conv_out"](controlnet_cond)
1040
+
1041
+ hidden_state = self.conv_in(hidden_state)
1042
+
1043
+ hidden_state = hidden_state + controlnet_cond
1044
+
1045
+ residuals = [hidden_state]
1046
+
1047
+ add_to_down_block_input = self.controlnet_down_blocks[0](hidden_state)
1048
+ add_to_down_block_inputs = [add_to_down_block_input]
1049
+
1050
+ for down_block in self.down_blocks:
1051
+ for i, resnet in enumerate(down_block["resnets"]):
1052
+ hidden_state = resnet(hidden_state, t)
1053
+
1054
+ if "attentions" in down_block:
1055
+ hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)
1056
+
1057
+ if len(add_to_down_block_inputs) < len(self.controlnet_down_blocks):
1058
+ add_to_down_block_input = self.controlnet_down_blocks[len(add_to_down_block_inputs)](hidden_state)
1059
+ add_to_down_block_inputs.append(add_to_down_block_input)
1060
+
1061
+ residuals.append(hidden_state)
1062
+
1063
+ if "downsamplers" in down_block:
1064
+ hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)
1065
+
1066
+ if len(add_to_down_block_inputs) < len(self.controlnet_down_blocks):
1067
+ add_to_down_block_input = self.controlnet_down_blocks[len(add_to_down_block_inputs)](hidden_state)
1068
+ add_to_down_block_inputs.append(add_to_down_block_input)
1069
+
1070
+ residuals.append(hidden_state)
1071
+
1072
+ hidden_state = self.mid_block["resnets"][0](hidden_state, t)
1073
+ hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
1074
+ hidden_state = self.mid_block["resnets"][1](hidden_state, t)
1075
+
1076
+ mid_block_res_sample = self.controlnet_mid_block(hidden_state)
1077
+
1078
+ down_block_res_samples = []
1079
+
1080
+ for up_block in self.up_blocks:
1081
+ for i, resnet in enumerate(up_block["resnets"]):
1082
+ residual = residuals.pop()
1083
+
1084
+ hidden_state = torch.concat([hidden_state, residual], dim=1)
1085
+
1086
+ hidden_state = resnet(hidden_state, t)
1087
+
1088
+ if "attentions" in up_block:
1089
+ hidden_state = up_block["attentions"][i](hidden_state, encoder_hidden_states)
1090
+
1091
+ down_block_res_sample = self.controlnet_up_blocks[len(down_block_res_samples)](hidden_state)
1092
+ down_block_res_samples.insert(0, down_block_res_sample)
1093
+
1094
+ if "upsamplers" in up_block:
1095
+ hidden_state = F.interpolate(hidden_state, scale_factor=2.0, mode="nearest")
1096
+ hidden_state = up_block["upsamplers"][0]["conv"](hidden_state)
1097
+
1098
+ hidden_state = self.conv_norm_out(hidden_state)
1099
+ hidden_state = self.conv_act(hidden_state)
1100
+ hidden_state = self.conv_out(hidden_state)
1101
+
1102
+ add_to_output = self.controlnet_conv_out(hidden_state)
1103
+
1104
+ return dict(
1105
+ down_block_res_samples=down_block_res_samples,
1106
+ mid_block_res_sample=mid_block_res_sample,
1107
+ add_to_down_block_inputs=add_to_down_block_inputs,
1108
+ add_to_output=add_to_output,
1109
+ )
1110
+
1111
+ @classmethod
1112
+ def from_unet(cls, unet):
1113
+ controlnet = cls()
1114
+
1115
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
1116
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
1117
+
1118
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
1119
+
1120
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
1121
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
1122
+ controlnet.up_blocks.load_state_dict(unet.up_blocks.state_dict())
1123
+
1124
+ controlnet.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict())
1125
+ controlnet.conv_out.load_state_dict(unet.conv_out.state_dict())
1126
+
1127
+ return controlnet
1128
+
1129
+
1130
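+ # T2I-Adapter style conditioning: pixel-unshuffles the conditioning image and runs a small conv
+ # trunk, returning one feature map per UNet resolution to be used as down_block_additional_residuals.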
+ class SDXLAdapter(nn.Module, ModelUtils):
1131
+ def __init__(self):
1132
+ super().__init__()
1133
+
1134
+ # fmt: off
1135
+
1136
+ self.adapter = nn.ModuleDict(dict(
1137
+ # 3 -> 768
1138
+ unshuffle=nn.PixelUnshuffle(16),
1139
+
1140
+ # 768 -> 320
1141
+ conv_in=nn.Conv2d(768, 320, kernel_size=3, padding=1),
1142
+
1143
+ body=nn.ModuleList([
1144
+ # 320 -> 320
1145
+ nn.ModuleDict(dict(
1146
+ resnets=nn.ModuleList([
+ nn.ModuleDict(dict(block1=nn.Conv2d(320, 320, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(320, 320, kernel_size=1))),
+ nn.ModuleDict(dict(block1=nn.Conv2d(320, 320, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(320, 320, kernel_size=1))),
+ ])
1150
+ )),
1151
+ # 320 -> 640
1152
+ nn.ModuleDict(dict(
1153
+ in_conv=nn.Conv2d(320, 640, kernel_size=1),
1154
+ resnets=nn.ModuleList([
+ nn.ModuleDict(dict(block1=nn.Conv2d(640, 640, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(640, 640, kernel_size=1))),
+ nn.ModuleDict(dict(block1=nn.Conv2d(640, 640, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(640, 640, kernel_size=1))),
+ ])
1158
+ )),
1159
+ # 640 -> 1280
1160
+ nn.ModuleDict(dict(
1161
+ downsample=nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
1162
+ in_conv=nn.Conv2d(640, 1280, kernel_size=1),
1163
+ resnets=nn.ModuleList([
+ nn.ModuleDict(dict(block1=nn.Conv2d(1280, 1280, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(1280, 1280, kernel_size=1))),
+ nn.ModuleDict(dict(block1=nn.Conv2d(1280, 1280, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(1280, 1280, kernel_size=1))),
+ ])
1167
+ )),
1168
+ # 1280 -> 1280
1169
+ nn.ModuleDict(dict(
1170
+ resnets=nn.ModuleList([
+ nn.ModuleDict(dict(block1=nn.Conv2d(1280, 1280, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(1280, 1280, kernel_size=1))),
+ nn.ModuleDict(dict(block1=nn.Conv2d(1280, 1280, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(1280, 1280, kernel_size=1))),
+ ])
1174
+ )),
1175
+ ])
1176
+ ))
1177
+
1178
+ # fmt: on
1179
+
1180
+ def forward(self, x):
1181
+ x = self.adapter["unshuffle"](x)
+ x = self.adapter["conv_in"](x)
+
+ features = []
+
+ for block in self.adapter["body"]:
1187
+ if "downsample" in block:
1188
+ x = block["downsample"](x)
1189
+
1190
+ if "in_conv" in block:
1191
+ x = block["in_conv"](x)
1192
+
1193
+ for resnet in block["resnets"]:
1194
+ residual = x
1195
+ x = resnet["block1"](x)
1196
+ x = resnet["act"](x)
1197
+ x = resnet["block2"](x)
1198
+ x = residual + x
1199
+
1200
+ features.append(x)
1201
+
1202
+ return features
1203
+
1204
+
1205
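+ # Standard sinusoidal position embedding: frequencies 10000^(-k / half_dim); returns concatenated
+ # cos/sin terms with `embedding_dim` channels per index.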
+ def get_sinusoidal_embedding(
1206
+ indices: torch.Tensor,
1207
+ embedding_dim: int,
1208
+ ):
1209
+ half_dim = embedding_dim // 2
1210
+ exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=indices.device)
1211
+ exponent = exponent / half_dim
1212
+
1213
+ emb = torch.exp(exponent)
1214
+ emb = indices.unsqueeze(-1).float() * emb
1215
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
1216
+
1217
+ return emb
1218
+
1219
+
1220
+ class ResnetBlock2D(nn.Module):
+     def __init__(self, in_channels, out_channels, time_embedding_dim=None, eps=1e-5):
+         super().__init__()
+
+         if time_embedding_dim is not None:
+             self.time_emb_proj = nn.Linear(time_embedding_dim, out_channels)
+         else:
+             self.time_emb_proj = None
+
+         self.norm1 = torch.nn.GroupNorm(32, in_channels, eps=eps)
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+         self.norm2 = nn.GroupNorm(32, out_channels, eps=eps)
+         self.dropout = nn.Dropout(0.0)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+         self.nonlinearity = nn.SiLU()
+
+         if in_channels != out_channels:
+             self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+         else:
+             self.conv_shortcut = None
+
+     def forward(self, hidden_states, temb=None):
+         residual = hidden_states
+
+         if self.time_emb_proj is not None:
+             assert temb is not None
+             temb = self.nonlinearity(temb)
+             temb = self.time_emb_proj(temb)[:, :, None, None]
+
+         hidden_states = self.norm1(hidden_states)
+         hidden_states = self.nonlinearity(hidden_states)
+         hidden_states = self.conv1(hidden_states)
+
+         if temb is not None:
+             hidden_states = hidden_states + temb
+
+         hidden_states = self.norm2(hidden_states)
+         hidden_states = self.nonlinearity(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.conv2(hidden_states)
+
+         if self.conv_shortcut is not None:
+             residual = self.conv_shortcut(residual)
+
+         hidden_states = hidden_states + residual
+
+         return hidden_states
+
+
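+ # Illustrative example for ResnetBlock2D above (editorial sketch, not part of the original
+ # source). The block is a GroupNorm/SiLU pre-activation residual block; the optional timestep
+ # embedding is projected and added between the two convolutions, and a 1x1 shortcut matches the
+ # residual when the channel count changes:
+ #
+ #     block = ResnetBlock2D(320, 640, time_embedding_dim=1280)
+ #     out = block(torch.randn(1, 320, 64, 64), temb=torch.randn(1, 1280))
+ #     assert out.shape == (1, 640, 64, 64)
+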
+ class TransformerDecoder2D(nn.Module):
+     def __init__(self, channels, encoder_hidden_states_dim, num_transformer_blocks):
+         super().__init__()
+
+         self.norm = nn.GroupNorm(32, channels, eps=1e-06)
+         self.proj_in = nn.Linear(channels, channels)
+
+         self.transformer_blocks = nn.ModuleList([TransformerDecoderBlock(channels, encoder_hidden_states_dim) for _ in range(num_transformer_blocks)])
+
+         self.proj_out = nn.Linear(channels, channels)
+
+     def forward(self, hidden_states, encoder_hidden_states):
+         batch_size, channels, height, width = hidden_states.shape
+
+         residual = hidden_states
+
+         hidden_states = self.norm(hidden_states)
+         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
+         hidden_states = self.proj_in(hidden_states)
+
+         for block in self.transformer_blocks:
+             hidden_states = block(hidden_states, encoder_hidden_states)
+
+         hidden_states = self.proj_out(hidden_states)
+         hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2).contiguous()
+
+         hidden_states = hidden_states + residual
+
+         return hidden_states
+
+
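+ # Illustrative example for TransformerDecoder2D above (editorial sketch, not part of the original
+ # source). The feature map is flattened to a (batch, height * width, channels) sequence, run
+ # through the transformer blocks with cross-attention over the text encoder states, and reshaped
+ # back with a long residual. The Attention module below relies on xformers, so this assumes a
+ # CUDA device with xformers installed:
+ #
+ #     decoder = TransformerDecoder2D(640, encoder_hidden_states_dim=2048, num_transformer_blocks=2).cuda()
+ #     out = decoder(torch.randn(1, 640, 32, 32, device="cuda"), torch.randn(1, 77, 2048, device="cuda"))
+ #     assert out.shape == (1, 640, 32, 32)
+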
+ class TransformerDecoderBlock(nn.Module):
+     def __init__(self, channels, encoder_hidden_states_dim):
+         super().__init__()
+
+         self.norm1 = nn.LayerNorm(channels)
+         self.attn1 = Attention(channels, channels)
+
+         self.norm2 = nn.LayerNorm(channels)
+         self.attn2 = Attention(channels, encoder_hidden_states_dim)
+
+         self.norm3 = nn.LayerNorm(channels)
+         self.ff = nn.ModuleDict(dict(net=nn.Sequential(GEGLU(channels, 4 * channels), nn.Dropout(0.0), nn.Linear(4 * channels, channels))))
+
+     def forward(self, hidden_states, encoder_hidden_states):
+         hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+
+         hidden_states = self.attn2(self.norm2(hidden_states), encoder_hidden_states) + hidden_states
+
+         hidden_states = self.ff["net"](self.norm3(hidden_states)) + hidden_states
+
+         return hidden_states
+
+
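+ # Editorial note (not part of the original source): each TransformerDecoderBlock is pre-LayerNorm
+ # with residual connections around self-attention (attn1), cross-attention over the text states
+ # (attn2), and a GEGLU feed-forward, i.e. the standard Stable Diffusion transformer block layout.
+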
+ class Attention(nn.Module):
+     def __init__(self, channels, encoder_hidden_states_dim, qkv_bias=False):
+         super().__init__()
+         self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
+         self.to_k = nn.Linear(encoder_hidden_states_dim, channels, bias=qkv_bias)
+         self.to_v = nn.Linear(encoder_hidden_states_dim, channels, bias=qkv_bias)
+         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
+
+     def forward(self, hidden_states, encoder_hidden_states=None):
+         batch_size, q_seq_len, channels = hidden_states.shape
+         head_dim = 64
+
+         if encoder_hidden_states is not None:
+             kv = encoder_hidden_states
+         else:
+             kv = hidden_states
+
+         kv_seq_len = kv.shape[1]
+
+         query = self.to_q(hidden_states)
+         key = self.to_k(kv)
+         value = self.to_v(kv)
+
+         query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
+         key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+         value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
+
+         hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
+
+         hidden_states = hidden_states.to(query.dtype)
+         hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
+
+         hidden_states = self.to_out(hidden_states)
+
+         return hidden_states
+
+
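+ # Editorial note (not part of the original source): queries, keys and values are reshaped to
+ # (batch, seq_len, num_heads, 64) because xformers.ops.memory_efficient_attention expects heads
+ # on the third axis; with head_dim fixed at 64 the head count is channels // 64 (10 heads at
+ # 640 channels, 20 at 1280). A rough cross-attention sketch, assuming xformers and a CUDA device:
+ #
+ #     attn = Attention(640, encoder_hidden_states_dim=2048).cuda()
+ #     out = attn(torch.randn(1, 1024, 640, device="cuda"), torch.randn(1, 77, 2048, device="cuda"))
+ #     assert out.shape == (1, 1024, 640)
+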
+ class GEGLU(nn.Module):
+     def __init__(self, dim_in: int, dim_out: int):
+         super().__init__()
+         self.proj = nn.Linear(dim_in, dim_out * 2)
+
+     def forward(self, hidden_states):
+         hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+         return hidden_states * F.gelu(gate)
+
+
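+ # Editorial note (not part of the original source): GEGLU projects to twice the output width and
+ # gates one half with GELU of the other, so the feed-forward activation stays dim_out wide:
+ #
+ #     geglu = GEGLU(320, 1280)
+ #     assert geglu(torch.randn(2, 77, 320)).shape == (2, 77, 1280)
+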
+ def zero_module(module):
+     for p in module.parameters():
+         nn.init.zeros_(p)
+     return module
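+
+
+ # Editorial note (not part of the original source): zero_module zero-initializes every parameter
+ # of a module in place and returns it, the usual ControlNet/adapter trick so that the
+ # zero-initialized projection layers contribute nothing when training starts:
+ #
+ #     conv = zero_module(nn.Conv2d(320, 320, kernel_size=1))
+ #     assert conv.weight.abs().sum() == 0 and conv.bias.abs().sum() == 0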