williamberman committed
Commit f0e6b7a
1 Parent(s): 3e48ac3

init comparison

Files changed:
- app.py +16 -4
- diffusion.py +58 -0
- load_state_dict_patch.py +415 -0
- sdxl.py +962 -0
- sdxl_models.py +1375 -0
app.py CHANGED
@@ -1,13 +1,20 @@
 import gradio as gr
 import torch
 
 from diffusers import AutoPipelineForInpainting
 import diffusers
 from share_btn import community_icon_html, loading_icon_html, share_js
+from sdxl import gen_sdxl_simplified_interface
+from sdxl_models import SDXLUNet, SDXLVae, SDXLControlNetPreEncodedControlnetCond
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)
 
+comparing_unet = SDXLUNet.load_fp16(device=device)
+comparing_vae = SDXLVae.load_fp16_fix(device=device)
+comparing_controlnet = SDXLControlNetPreEncodedControlnetCond.load("", device="cuda")  # TODO - upload checkpoint
+comparing_controlnet.to(torch.float16)
+
 def read_content(file_path: str) -> str:
     """read the content of target file
     """
@@ -34,8 +41,12 @@ def predict(dict, prompt="", negative_prompt="", guidance_scale=7.5, steps=20, s
     mask = dict["mask"].convert("RGB").resize((1024, 1024))
 
     output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
+    output_controlnet_vae_encoding = gen_sdxl_simplified_interface(
+        prompt=prompt, negative_prompt=negative_prompt, images=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps),
+        text_encoder_one=pipe.text_encoder, text_encoder_two=pipe.text_encoder_2, unet=comparing_unet, vae=comparing_vae, controlnet=comparing_controlnet, device=device
+    )
 
-    return output.images[0], gr.update(visible=True)
+    return output.images[0], output_controlnet_vae_encoding[0], gr.update(visible=True)
 
 
 css = '''
@@ -98,14 +109,15 @@ with image_blocks as demo:
 
         with gr.Column():
             image_out = gr.Image(label="Output", elem_id="output-img", height=400)
+            image_out_comparing = gr.Image(label="Output", elem_id="output-img-comparing", height=400)
             with gr.Group(elem_id="share-btn-container", visible=False) as share_btn_container:
                 community_icon = gr.HTML(community_icon_html)
                 loading_icon = gr.HTML(loading_icon_html)
                 share_button = gr.Button("Share to community", elem_id="share-btn",visible=True)
 
 
-    btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, share_btn_container], api_name='run')
-    prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, share_btn_container])
+    btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container], api_name='run')
+    prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, strength, scheduler], outputs=[image_out, image_out_comparing, share_btn_container])
     share_button.click(None, [], [], _js=share_js)
 
     gr.Examples(
diffusion.py ADDED
@@ -0,0 +1,58 @@
import torch

default_num_train_timesteps = 1000


@torch.no_grad()
def make_sigmas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=default_num_train_timesteps, device=None):
    betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32, device=device) ** 2

    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # TODO - would be nice to use a direct expression for this
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

    return sigmas


@torch.no_grad()
def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_weights):
    x_t = x_T

    for i in range(len(timesteps) - 1, -1, -1):
        t = timesteps[i]

        sigma = sigmas[i]

        if i == 0:
            eps_hat = eps_theta(x_t=x_t, t=t, sigma=sigma)
            x_0_hat = x_t - sigma * eps_hat
        else:
            dt = sigmas[i - 1] - sigma

            dx_by_dt = torch.zeros_like(x_t)
            dx_by_dt_cur = torch.zeros_like(x_t)

            for rk_step, rk_weight in rk_steps_weights:
                dt_ = dt * rk_step
                t_ = t + dt_
                x_t_ = x_t + dx_by_dt_cur * dt_
                eps_hat = eps_theta(x_t=x_t_, t=t_, sigma=sigma)
                # TODO - note which specific ode this is the solution to and
                # how input scaling does/doesn't effect the solution
                dx_by_dt_cur = (x_t_ - sigma * eps_hat) / sigma
                dx_by_dt += dx_by_dt_cur * rk_weight

            x_t_minus_1 = x_t + dx_by_dt * dt

            x_t = x_t_minus_1

    return x_0_hat


euler_ode_solver_diffusion_loop = lambda *args, **kwargs: rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 1]])

heun_ode_solver_diffusion_loop = lambda *args, **kwargs: rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 0.5], [1, 0.5]])

rk4_ode_solver_diffusion_loop = lambda *args, **kwargs: rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 1 / 6], [1 / 2, 1 / 3], [1 / 2, 1 / 3], [1, 1 / 6]])
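
Editor's note on the solvers above: the loop calls eps_theta with x_t, t, and sigma keyword arguments and appears to read sigmas[i] in parallel with timesteps[i], so the caller is expected to pass per-step sigmas alongside the chosen timesteps. The following is a minimal, hypothetical driver (not from this commit) with a dummy model, sketched under those assumptions:

    import torch

    from diffusion import make_sigmas, euler_ode_solver_diffusion_loop

    # Dummy epsilon prediction; a real caller would run the SDXL UNet here (see sdxl.py).
    def eps_theta(x_t, t, sigma):
        return torch.zeros_like(x_t)

    all_sigmas = make_sigmas()  # one sigma per training timestep, 1000 by default

    # A few inference timesteps and their matching sigmas, kept parallel because the
    # loop indexes sigmas[i] next to timesteps[i].
    timesteps = torch.linspace(0, 999, 8).long()
    sigmas = all_sigmas[timesteps]

    # Start from noise scaled by the largest sigma; the latent shape is illustrative only.
    x_T = torch.randn(1, 4, 128, 128) * sigmas[-1]

    x_0_hat = euler_ode_solver_diffusion_loop(eps_theta=eps_theta, timesteps=timesteps, sigmas=sigmas, x_T=x_T)
    print(x_0_hat.shape)

Swapping in heun_ode_solver_diffusion_loop or rk4_ode_solver_diffusion_loop only changes the rk_steps_weights table; the call signature stays the same.
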
load_state_dict_patch.py ADDED
@@ -0,0 +1,415 @@
import itertools
from collections import OrderedDict
from typing import Any, List, Mapping

import torch
from torch.nn import Module
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys

# fmt: off

# this patch is for adding the `assign` key to load_state_dict.
# the code is in pytorch source for version 2.1

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    r"""Copies parameters and buffers from :attr:`state_dict` into only
    this module, but not its descendants. This is called on every submodule
    in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
    module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
    For state dicts without metadata, :attr:`local_metadata` is empty.
    Subclasses can achieve class-specific backward compatible loading using
    the version number at `local_metadata.get("version", None)`.
    Additionally, :attr:`local_metadata` can also contain the key
    `assign_to_params_buffers` that indicates whether keys should be
    assigned their corresponding tensor in the state_dict.

    .. note::
        :attr:`state_dict` is not the same object as the input
        :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
        it can be modified.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix for parameters and buffers used in this
            module
        local_metadata (dict): a dict containing the metadata for this module.
            See
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` with :attr:`prefix` match the names of
            parameters and buffers in this module
        missing_keys (list of str): if ``strict=True``, add missing keys to
            this list
        unexpected_keys (list of str): if ``strict=True``, add unexpected
            keys to this list
        error_msgs (list of str): error messages should be added to this
            list, and will be reported together in
            :meth:`~torch.nn.Module.load_state_dict`
    """
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
    local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}
    assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]
            if not torch.overrides.is_tensor_like(input_param):
                error_msgs.append('While copying the parameter named "{}", '
                                  'expected torch.Tensor or Tensor-like object from checkpoint but '
                                  'received {}'
                                  .format(key, type(input_param)))
                continue

            # This is used to avoid copying uninitialized parameters into
            # non-lazy modules, since they dont have the hook to do the checks
            # in such case, it will error when accessing the .shape attribute.
            is_param_lazy = torch.nn.parameter.is_lazy(param)
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]

            if not is_param_lazy and input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'
                                  .format(key, input_param.shape, param.shape))
                continue
            try:
                with torch.no_grad():
                    if assign_to_params_buffers:
                        # Shape checks are already done above
                        if (isinstance(param, torch.nn.Parameter) and
                                not isinstance(input_param, torch.nn.Parameter)):
                            setattr(self, name, torch.nn.Parameter(input_param))
                        else:
                            setattr(self, name, input_param)
                    else:
                        param.copy_(input_param)
            except Exception as ex:
                error_msgs.append('While copying the parameter named "{}", '
                                  'whose dimensions in the model are {} and '
                                  'whose dimensions in the checkpoint are {}, '
                                  'an exception occurred : {}.'
                                  .format(key, param.size(), input_param.size(), ex.args))
        elif strict:
            missing_keys.append(key)

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix) and key != extra_state_key:
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)

def load_state_dict(self, state_dict: Mapping[str, Any],
                    strict: bool = True, assign: bool = False):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants. If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this module's :meth:`~torch.nn.Module.state_dict` function.

    .. warning::
        If :attr:`assign` is ``True`` the optimizer must be created after
        the call to :attr:`load_state_dict`.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
        assign (bool, optional): whether to assign items in the state
            dictionary to their corresponding keys in the module instead
            of copying them inplace into the module's current parameters and buffers.
            When ``False``, the properties of the tensors in the current
            module are preserved while when ``True``, the properties of the
            Tensors in the state dict are preserved.
            Default: ``False``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Note:
        If a parameter or buffer is registered as ``None`` and its corresponding key
        exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
        ``RuntimeError``.
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(module, local_state_dict, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        if assign:
            local_metadata['assign_to_params_buffers'] = assign
        module._load_from_state_dict(
            local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + '.'
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)

        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            assert out is None, (
                "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                "expected to return new values, if incompatible_keys need to be modified,"
                "it should be done inplace."
            )

    load(self, state_dict)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                           self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)

if [int(x) for x in torch.__version__.split('.')[0:2]] < [2, 1]:
    Module._load_from_state_dict = _load_from_state_dict
    Module.load_state_dict = load_state_dict

# fmt: on
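
The patch above back-ports the assign flag that PyTorch 2.1 added to torch.nn.Module.load_state_dict: with assign=True the checkpoint tensors are attached to the module directly instead of being copied into its existing parameters, so the loaded weights keep the checkpoint's dtype and storage. A small, hypothetical illustration follows; the Linear module and inline checkpoint are stand-ins (this repo would load real safetensors checkpoints into its SDXL modules):

    import torch

    import load_state_dict_patch  # noqa: F401  importing the module applies the patch on torch < 2.1

    # Stand-in for the models in this repo (SDXLUNet, SDXLVae, ...).
    model = torch.nn.Linear(4, 4)

    # Stand-in checkpoint; a real one would come from safetensors.torch.load_file(...).
    checkpoint = {
        "weight": torch.randn(4, 4, dtype=torch.float16),
        "bias": torch.randn(4, dtype=torch.float16),
    }

    # assign=True swaps the checkpoint tensors in rather than copy_()-ing them into
    # the existing fp32 parameters, so the parameters end up in the checkpoint dtype.
    model.load_state_dict(checkpoint, assign=True)
    print(model.weight.dtype)  # torch.float16
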
sdxl.py ADDED
@@ -0,0 +1,962 @@
1 |
+
import itertools
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import safetensors.torch
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchvision.transforms
|
11 |
+
import torchvision.transforms.functional as TF
|
12 |
+
import wandb
|
13 |
+
import webdataset as wds
|
14 |
+
from PIL import Image
|
15 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
16 |
+
from torch.utils.data import default_collate
|
17 |
+
from transformers import (CLIPTextModel, CLIPTextModelWithProjection,
|
18 |
+
CLIPTokenizerFast)
|
19 |
+
|
20 |
+
from diffusion import (default_num_train_timesteps,
|
21 |
+
euler_ode_solver_diffusion_loop, make_sigmas)
|
22 |
+
from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull,
|
23 |
+
SDXLControlNetPreEncodedControlnetCond, SDXLUNet,
|
24 |
+
SDXLVae)
|
25 |
+
|
26 |
+
|
27 |
+
class SDXLTraining:
|
28 |
+
text_encoder_one: CLIPTextModel
|
29 |
+
text_encoder_two: CLIPTextModelWithProjection
|
30 |
+
vae: SDXLVae
|
31 |
+
sigmas: torch.Tensor
|
32 |
+
unet: SDXLUNet
|
33 |
+
adapter: Optional[SDXLAdapter]
|
34 |
+
controlnet: Optional[Union[SDXLControlNet, SDXLControlNetFull]]
|
35 |
+
|
36 |
+
train_unet: bool
|
37 |
+
train_unet_up_blocks: bool
|
38 |
+
|
39 |
+
mixed_precision: Optional[torch.dtype]
|
40 |
+
timestep_sampling: Literal["uniform", "cubic"]
|
41 |
+
|
42 |
+
validation_images_logged: bool
|
43 |
+
log_validation_input_images_every_time: bool
|
44 |
+
|
45 |
+
get_sdxl_conditioning_images: Callable[[Image.Image], Dict[str, Any]]
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
device,
|
50 |
+
train_unet,
|
51 |
+
get_sdxl_conditioning_images,
|
52 |
+
train_unet_up_blocks=False,
|
53 |
+
unet_resume_from=None,
|
54 |
+
controlnet_cls=None,
|
55 |
+
controlnet_resume_from=None,
|
56 |
+
adapter_cls=None,
|
57 |
+
adapter_resume_from=None,
|
58 |
+
mixed_precision=None,
|
59 |
+
timestep_sampling="uniform",
|
60 |
+
log_validation_input_images_every_time=True,
|
61 |
+
):
|
62 |
+
self.text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16)
|
63 |
+
self.text_encoder_one.to(device=device)
|
64 |
+
self.text_encoder_one.requires_grad_(False)
|
65 |
+
self.text_encoder_one.eval()
|
66 |
+
|
67 |
+
self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16)
|
68 |
+
self.text_encoder_two.to(device=device)
|
69 |
+
self.text_encoder_two.requires_grad_(False)
|
70 |
+
self.text_encoder_two.eval()
|
71 |
+
|
72 |
+
self.vae = SDXLVae.load_fp16_fix(device=device)
|
73 |
+
self.vae.requires_grad_(False)
|
74 |
+
self.vae.eval()
|
75 |
+
|
76 |
+
self.sigmas = make_sigmas(device=device)
|
77 |
+
|
78 |
+
if train_unet:
|
79 |
+
if unet_resume_from is None:
|
80 |
+
self.unet = SDXLUNet.load_fp32(device=device)
|
81 |
+
else:
|
82 |
+
self.unet = SDXLUNet.load(unet_resume_from, device=device)
|
83 |
+
self.unet.requires_grad_(True)
|
84 |
+
self.unet.train()
|
85 |
+
self.unet = DDP(self.unet, device_ids=[device])
|
86 |
+
elif train_unet_up_blocks:
|
87 |
+
if unet_resume_from is None:
|
88 |
+
self.unet = SDXLUNet.load_fp32(device=device)
|
89 |
+
else:
|
90 |
+
self.unet = SDXLUNet.load_fp32(device=device, overrides=[unet_resume_from])
|
91 |
+
self.unet.requires_grad_(False)
|
92 |
+
self.unet.eval()
|
93 |
+
self.unet.up_blocks.requires_grad_(True)
|
94 |
+
self.unet.up_blocks.train()
|
95 |
+
self.unet = DDP(self.unet, device_ids=[device], find_unused_parameters=True)
|
96 |
+
else:
|
97 |
+
self.unet = SDXLUNet.load_fp16(device=device)
|
98 |
+
self.unet.requires_grad_(False)
|
99 |
+
self.unet.eval()
|
100 |
+
|
101 |
+
if controlnet_cls is not None:
|
102 |
+
if controlnet_resume_from is None:
|
103 |
+
self.controlnet = controlnet_cls.from_unet(self.unet)
|
104 |
+
self.controlnet.to(device)
|
105 |
+
else:
|
106 |
+
self.controlnet = controlnet_cls.load(controlnet_resume_from, device=device)
|
107 |
+
self.controlnet.train()
|
108 |
+
self.controlnet.requires_grad_(True)
|
109 |
+
# TODO add back
|
110 |
+
# controlnet.enable_gradient_checkpointing()
|
111 |
+
# TODO - should be able to remove find_unused_parameters. Comes from pre encoded controlnet
|
112 |
+
self.controlnet = DDP(self.controlnet, device_ids=[device], find_unused_parameters=True)
|
113 |
+
else:
|
114 |
+
self.controlnet = None
|
115 |
+
|
116 |
+
if adapter_cls is not None:
|
117 |
+
if adapter_resume_from is None:
|
118 |
+
self.adapter = adapter_cls()
|
119 |
+
self.adapter.to(device=device)
|
120 |
+
else:
|
121 |
+
self.adapter = adapter_cls.load(adapter_resume_from, device=device)
|
122 |
+
self.adapter.train()
|
123 |
+
self.adapter.requires_grad_(True)
|
124 |
+
self.adapter = DDP(self.adapter, device_ids=[device])
|
125 |
+
else:
|
126 |
+
self.adapter = None
|
127 |
+
|
128 |
+
self.mixed_precision = mixed_precision
|
129 |
+
self.timestep_sampling = timestep_sampling
|
130 |
+
|
131 |
+
self.validation_images_logged = False
|
132 |
+
self.log_validation_input_images_every_time = log_validation_input_images_every_time
|
133 |
+
|
134 |
+
self.get_sdxl_conditioning_images = get_sdxl_conditioning_images
|
135 |
+
|
136 |
+
self.train_unet = train_unet
|
137 |
+
self.train_unet_up_blocks = train_unet_up_blocks
|
138 |
+
|
139 |
+
def train_step(self, batch):
|
140 |
+
with torch.no_grad():
|
141 |
+
if isinstance(self.unet, DDP):
|
142 |
+
unet_dtype = self.unet.module.dtype
|
143 |
+
unet_device = self.unet.module.device
|
144 |
+
else:
|
145 |
+
unet_dtype = self.unet.dtype
|
146 |
+
unet_device = self.unet.device
|
147 |
+
|
148 |
+
micro_conditioning = batch["micro_conditioning"].to(device=unet_device)
|
149 |
+
|
150 |
+
image = batch["image"].to(self.vae.device, dtype=self.vae.dtype)
|
151 |
+
latents = self.vae.encode(image).to(dtype=unet_dtype)
|
152 |
+
|
153 |
+
text_input_ids_one = batch["text_input_ids_one"].to(self.text_encoder_one.device)
|
154 |
+
text_input_ids_two = batch["text_input_ids_two"].to(self.text_encoder_two.device)
|
155 |
+
|
156 |
+
encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(self.text_encoder_one, self.text_encoder_two, text_input_ids_one, text_input_ids_two)
|
157 |
+
|
158 |
+
encoder_hidden_states = encoder_hidden_states.to(dtype=unet_dtype)
|
159 |
+
pooled_encoder_hidden_states = pooled_encoder_hidden_states.to(dtype=unet_dtype)
|
160 |
+
|
161 |
+
bsz = latents.shape[0]
|
162 |
+
|
163 |
+
if self.timestep_sampling == "uniform":
|
164 |
+
timesteps = torch.randint(0, default_num_train_timesteps, (bsz,), device=unet_device)
|
165 |
+
elif self.timestep_sampling == "cubic":
|
166 |
+
# Cubic sampling to sample a random timestep for each image
|
167 |
+
timesteps = torch.rand((bsz,), device=unet_device)
|
168 |
+
timesteps = (1 - timesteps**3) * default_num_train_timesteps
|
169 |
+
timesteps = timesteps.long()
|
170 |
+
timesteps = timesteps.clamp(0, default_num_train_timesteps - 1)
|
171 |
+
else:
|
172 |
+
assert False
|
173 |
+
|
174 |
+
sigmas_ = self.sigmas[timesteps].to(dtype=latents.dtype)
|
175 |
+
|
176 |
+
noise = torch.randn_like(latents)
|
177 |
+
|
178 |
+
noisy_latents = latents + noise * sigmas_
|
179 |
+
|
180 |
+
scaled_noisy_latents = noisy_latents / ((sigmas_**2 + 1) ** 0.5)
|
181 |
+
|
182 |
+
if "conditioning_image" in batch:
|
183 |
+
conditioning_image = batch["conditioning_image"].to(unet_device)
|
184 |
+
|
185 |
+
if self.controlnet is not None and isinstance(self.controlnet, SDXLControlNetPreEncodedControlnetCond):
|
186 |
+
controlnet_device = self.controlnet.module.device
|
187 |
+
controlnet_dtype = self.controlnet.module.dtype
|
188 |
+
conditioning_image = self.vae.encode(conditioning_image.to(self.vae.dtype)).to(device=controlnet_device, dtype=controlnet_dtype)
|
189 |
+
conditioning_image_mask = TF.resize(batch["conditioning_image_mask"], conditioning_image.shape[2:]).to(device=controlnet_device, dtype=controlnet_dtype)
|
190 |
+
conditioning_image = torch.concat((conditioning_image, conditioning_image_mask), dim=1)
|
191 |
+
|
192 |
+
with torch.autocast(
|
193 |
+
"cuda",
|
194 |
+
self.mixed_precision,
|
195 |
+
enabled=self.mixed_precision is not None,
|
196 |
+
):
|
197 |
+
down_block_additional_residuals = None
|
198 |
+
mid_block_additional_residual = None
|
199 |
+
add_to_down_block_inputs = None
|
200 |
+
add_to_output = None
|
201 |
+
|
202 |
+
if self.adapter is not None:
|
203 |
+
down_block_additional_residuals = self.adapter(conditioning_image)
|
204 |
+
|
205 |
+
if self.controlnet is not None:
|
206 |
+
controlnet_out = self.controlnet(
|
207 |
+
x_t=scaled_noisy_latents,
|
208 |
+
t=timesteps,
|
209 |
+
encoder_hidden_states=encoder_hidden_states,
|
210 |
+
micro_conditioning=micro_conditioning,
|
211 |
+
pooled_encoder_hidden_states=pooled_encoder_hidden_states,
|
212 |
+
controlnet_cond=conditioning_image,
|
213 |
+
)
|
214 |
+
|
215 |
+
down_block_additional_residuals = controlnet_out["down_block_res_samples"]
|
216 |
+
mid_block_additional_residual = controlnet_out["mid_block_res_sample"]
|
217 |
+
add_to_down_block_inputs = controlnet_out.get("add_to_down_block_inputs", None)
|
218 |
+
add_to_output = controlnet_out.get("add_to_output", None)
|
219 |
+
|
220 |
+
model_pred = self.unet(
|
221 |
+
x_t=scaled_noisy_latents,
|
222 |
+
t=timesteps,
|
223 |
+
encoder_hidden_states=encoder_hidden_states,
|
224 |
+
micro_conditioning=micro_conditioning,
|
225 |
+
pooled_encoder_hidden_states=pooled_encoder_hidden_states,
|
226 |
+
down_block_additional_residuals=down_block_additional_residuals,
|
227 |
+
mid_block_additional_residual=mid_block_additional_residual,
|
228 |
+
add_to_down_block_inputs=add_to_down_block_inputs,
|
229 |
+
add_to_output=add_to_output,
|
230 |
+
).sample
|
231 |
+
|
232 |
+
loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
|
233 |
+
|
234 |
+
return loss
|
235 |
+
|
236 |
+
@torch.no_grad()
|
237 |
+
def log_validation(self, step, num_validation_images: int, validation_prompts: Optional[List[str]] = None, validation_images: Optional[List[str]] = None):
|
238 |
+
if isinstance(self.unet, DDP):
|
239 |
+
unet = self.unet.module
|
240 |
+
unet.eval()
|
241 |
+
unet_set_to_eval = True
|
242 |
+
else:
|
243 |
+
unet = self.unet
|
244 |
+
unet_set_to_eval = False
|
245 |
+
|
246 |
+
if self.adapter is not None:
|
247 |
+
adapter = self.adapter.module
|
248 |
+
adapter.eval()
|
249 |
+
else:
|
250 |
+
adapter = None
|
251 |
+
|
252 |
+
if self.controlnet is not None:
|
253 |
+
controlnet = self.controlnet.module
|
254 |
+
controlnet.eval()
|
255 |
+
else:
|
256 |
+
controlnet = None
|
257 |
+
|
258 |
+
formatted_validation_images = None
|
259 |
+
|
260 |
+
if validation_images is not None:
|
261 |
+
formatted_validation_images = []
|
262 |
+
wandb_validation_images = []
|
263 |
+
|
264 |
+
for validation_image_path in validation_images:
|
265 |
+
validation_image = Image.open(validation_image_path)
|
266 |
+
validation_image = validation_image.convert("RGB")
|
267 |
+
validation_image = validation_image.resize((1024, 1024))
|
268 |
+
|
269 |
+
conditioning_images = self.get_sdxl_conditioning_images(validation_image)
|
270 |
+
|
271 |
+
conditioning_image = conditioning_images["conditioning_image"]
|
272 |
+
|
273 |
+
if self.controlnet is not None and isinstance(self.controlnet, SDXLControlNetPreEncodedControlnetCond):
|
274 |
+
conditioning_image = self.vae.encode(conditioning_image[None, :, :, :].to(self.vae.device, dtype=self.vae.dtype))
|
275 |
+
conditionin_mask_image = TF.resize(conditioning_images["conditioning_mask_image"], conditioning_image.shape[2:]).to(conditioning_image.dtype, conditioning_image.device)
|
276 |
+
conditioning_image = torch.concat(conditioning_image, conditionin_mask_image, dim=1)
|
277 |
+
|
278 |
+
formatted_validation_images.append(conditioning_image)
|
279 |
+
wandb_validation_images.append(wandb.Image(conditioning_images["conditioning_image_as_pil"]))
|
280 |
+
|
281 |
+
if self.log_validation_input_images_every_time or not self.validation_images_logged:
|
282 |
+
wandb.log({"validation_conditioning": wandb_validation_images}, step=step)
|
283 |
+
self.validation_images_logged = True
|
284 |
+
|
285 |
+
generator = torch.Generator().manual_seed(0)
|
286 |
+
|
287 |
+
output_validation_images = []
|
288 |
+
|
289 |
+
for formatted_validation_image, validation_prompt in zip(formatted_validation_images, validation_prompts):
|
290 |
+
for _ in range(num_validation_images):
|
291 |
+
with torch.autocast("cuda"):
|
292 |
+
x_0 = sdxl_diffusion_loop(
|
293 |
+
prompts=validation_prompt,
|
294 |
+
images=formatted_validation_image,
|
295 |
+
unet=unet,
|
296 |
+
text_encoder_one=self.text_encoder_one,
|
297 |
+
text_encoder_two=self.text_encoder_two,
|
298 |
+
controlnet=controlnet,
|
299 |
+
adapter=adapter,
|
300 |
+
sigmas=self.sigmas,
|
301 |
+
generator=generator,
|
302 |
+
)
|
303 |
+
|
304 |
+
x_0 = self.vae.decode(x_0)
|
305 |
+
x_0 = self.vae.output_tensor_to_pil(x_0)[0]
|
306 |
+
|
307 |
+
output_validation_images.append(wandb.Image(x_0, caption=validation_prompt))
|
308 |
+
|
309 |
+
wandb.log({"validation": output_validation_images}, step=step)
|
310 |
+
|
311 |
+
if unet_set_to_eval:
|
312 |
+
unet.train()
|
313 |
+
|
314 |
+
if adapter is not None:
|
315 |
+
adapter.train()
|
316 |
+
|
317 |
+
if controlnet is not None:
|
318 |
+
controlnet.train()
|
319 |
+
|
320 |
+
def parameters(self):
|
321 |
+
if self.train_unet:
|
322 |
+
return self.unet.parameters()
|
323 |
+
|
324 |
+
if self.controlnet is not None and self.train_unet_up_blocks:
|
325 |
+
return itertools.chain(self.controlnet.parameters(), self.unet.up_blocks.parameters())
|
326 |
+
|
327 |
+
if self.controlnet is not None:
|
328 |
+
return self.controlnet.parameters()
|
329 |
+
|
330 |
+
if self.adapter is not None:
|
331 |
+
return self.adapter.parameters()
|
332 |
+
|
333 |
+
assert False
|
334 |
+
|
335 |
+
def save(self, save_to):
|
336 |
+
if self.train_unet:
|
337 |
+
safetensors.torch.save_file(self.unet.module.state_dict(), os.path.join(save_to, "unet.safetensors"))
|
338 |
+
|
339 |
+
if self.controlnet is not None and self.train_unet_up_blocks:
|
340 |
+
safetensors.torch.save_file(self.controlnet.module.state_dict(), os.path.join(save_to, "controlnet.safetensors"))
|
341 |
+
safetensors.torch.save_file(self.unet.module.up_blocks.state_dict(), os.path.join(save_to, "unet.safetensors"))
|
342 |
+
|
343 |
+
if self.controlnet is not None:
|
344 |
+
safetensors.torch.save_file(self.controlnet.module.state_dict(), os.path.join(save_to, "controlnet.safetensors"))
|
345 |
+
|
346 |
+
if self.adapter is not None:
|
347 |
+
safetensors.torch.save_file(self.adapter.module.state_dict(), os.path.join(save_to, "adapter.safetensors"))
|
348 |
+
|
349 |
+
|
350 |
+
def get_sdxl_dataset(train_shards: str, shuffle_buffer_size: int, batch_size: int, proportion_empty_prompts: float, get_sdxl_conditioning_images=None):
|
351 |
+
dataset = (
|
352 |
+
wds.WebDataset(
|
353 |
+
train_shards,
|
354 |
+
resampled=True,
|
355 |
+
handler=wds.ignore_and_continue,
|
356 |
+
)
|
357 |
+
.shuffle(shuffle_buffer_size)
|
358 |
+
.decode("pil", handler=wds.ignore_and_continue)
|
359 |
+
.rename(
|
360 |
+
image="jpg;png;jpeg;webp",
|
361 |
+
text="text;txt;caption",
|
362 |
+
metadata="json",
|
363 |
+
handler=wds.warn_and_continue,
|
364 |
+
)
|
365 |
+
.map(lambda d: make_sample(d, proportion_empty_prompts=proportion_empty_prompts, get_sdxl_conditioning_images=get_sdxl_conditioning_images))
|
366 |
+
.select(lambda sample: "conditioning_image" not in sample or sample["conditioning_image"] is not None)
|
367 |
+
)
|
368 |
+
|
369 |
+
dataset = dataset.batched(batch_size, partial=False, collation_fn=default_collate)
|
370 |
+
|
371 |
+
return dataset
|
372 |
+
|
373 |
+
|
374 |
+
@torch.no_grad()
|
375 |
+
def make_sample(d, proportion_empty_prompts, get_sdxl_conditioning_images=None):
|
376 |
+
image = d["image"]
|
377 |
+
metadata = d["metadata"]
|
378 |
+
|
379 |
+
if random.random() < proportion_empty_prompts:
|
380 |
+
text = ""
|
381 |
+
else:
|
382 |
+
text = d["text"]
|
383 |
+
|
384 |
+
c_top, c_left, _, _ = get_random_crop_params([image.height, image.width], [1024, 1024])
|
385 |
+
|
386 |
+
original_width = int(metadata.get("original_width", 0.0))
|
387 |
+
original_height = int(metadata.get("original_height", 0.0))
|
388 |
+
|
389 |
+
micro_conditioning = torch.tensor([original_width, original_height, c_top, c_left, 1024, 1024])
|
390 |
+
|
391 |
+
text_input_ids_one = sdxl_tokenize_one(text)
|
392 |
+
|
393 |
+
text_input_ids_two = sdxl_tokenize_two(text)
|
394 |
+
|
395 |
+
image = image.convert("RGB")
|
396 |
+
|
397 |
+
image = TF.resize(
|
398 |
+
image,
|
399 |
+
1024,
|
400 |
+
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
|
401 |
+
)
|
402 |
+
|
403 |
+
image = TF.crop(
|
404 |
+
image,
|
405 |
+
c_top,
|
406 |
+
c_left,
|
407 |
+
1024,
|
408 |
+
1024,
|
409 |
+
)
|
410 |
+
|
411 |
+
sample = {
|
412 |
+
"micro_conditioning": micro_conditioning,
|
413 |
+
"text_input_ids_one": text_input_ids_one,
|
414 |
+
"text_input_ids_two": text_input_ids_two,
|
415 |
+
"image": SDXLVae.input_pil_to_tensor(image),
|
416 |
+
}
|
417 |
+
|
418 |
+
if get_sdxl_conditioning_images is not None:
|
419 |
+
conditioning_images = get_sdxl_conditioning_images(image)
|
420 |
+
|
421 |
+
sample["conditioning_image"] = conditioning_images["conditioning_image"]
|
422 |
+
|
423 |
+
if conditioning_images["conditioning_image_mask"] is not None:
|
424 |
+
sample["conditioning_image_mask"] = conditioning_images["conditioning_image_mask"]
|
425 |
+
|
426 |
+
return sample
|
427 |
+
|
428 |
+
|
429 |
+
def get_random_crop_params(input_size: Tuple[int, int], output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
|
430 |
+
h, w = input_size
|
431 |
+
|
432 |
+
th, tw = output_size
|
433 |
+
|
434 |
+
if h < th or w < tw:
|
435 |
+
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
|
436 |
+
|
437 |
+
if w == tw and h == th:
|
438 |
+
return 0, 0, h, w
|
439 |
+
|
440 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
441 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
442 |
+
|
443 |
+
return i, j, th, tw
|
444 |
+
|
445 |
+
|
446 |
+
def get_sdxl_conditioning_images(image, adapter_type=None, controlnet_type=None, controlnet_variant=None, open_pose=None, conditioning_image_mask=None):
|
447 |
+
resolution = image.width
|
448 |
+
|
449 |
+
if adapter_type == "openpose":
|
450 |
+
conditioning_image = open_pose(image, detect_resolution=resolution, image_resolution=resolution, return_pil=False)
|
451 |
+
|
452 |
+
if (conditioning_image == 0).all():
|
453 |
+
return None, None
|
454 |
+
|
455 |
+
conditioning_image_as_pil = Image.fromarray(conditioning_image)
|
456 |
+
|
457 |
+
conditioning_image = TF.to_tensor(conditioning_image)
|
458 |
+
|
459 |
+
if controlnet_type == "canny":
|
460 |
+
import cv2
|
461 |
+
|
462 |
+
conditioning_image = np.array(image)
|
463 |
+
conditioning_image = cv2.Canny(conditioning_image, 100, 200)
|
464 |
+
conditioning_image = conditioning_image[:, :, None]
|
465 |
+
conditioning_image = np.concatenate([conditioning_image, conditioning_image, conditioning_image], axis=2)
|
466 |
+
|
467 |
+
conditioning_image_as_pil = Image.fromarray(conditioning_image)
|
468 |
+
|
469 |
+
conditioning_image = TF.to_tensor(conditioning_image)
|
470 |
+
|
471 |
+
if controlnet_type == "inpainting":
|
472 |
+
if conditioning_image_mask is None:
|
473 |
+
if random.random() <= 0.25:
|
474 |
+
conditioning_image_mask = np.ones((resolution, resolution), np.float32)
|
475 |
+
else:
|
476 |
+
conditioning_image_mask = random.choice([make_random_rectangle_mask, make_random_irregular_mask, make_outpainting_mask])(resolution, resolution)
|
477 |
+
|
478 |
+
conditioning_image_mask = torch.from_numpy(conditioning_image_mask)
|
479 |
+
|
480 |
+
conditioning_image_mask = conditioning_image_mask[None, :, :]
|
481 |
+
|
482 |
+
conditioning_image = TF.to_tensor(image)
|
483 |
+
|
484 |
+
if controlnet_variant == "pre_encoded_controlnet_cond":
|
485 |
+
# where mask is 1, zero out the pixels. Note that this requires mask to be concattenated
|
486 |
+
# with the mask so that the network knows the zeroed out pixels are from the mask and
|
487 |
+
# are not just zero in the original image
|
488 |
+
conditioning_image = conditioning_image * (conditioning_image_mask < 0.5)
|
489 |
+
|
490 |
+
conditioning_image_as_pil = TF.to_pil_image(conditioning_image)
|
491 |
+
|
492 |
+
conditioning_image = TF.normalize(conditioning_image, [0.5], [0.5])
|
493 |
+
else:
|
494 |
+
# Just zero out the pixels which will be masked
|
495 |
+
conditioning_image_as_pil = TF.to_pil_image(conditioning_image * (conditioning_image_mask < 0.5))
|
496 |
+
|
497 |
+
# where mask is set to 1, set to -1 "special" masked image pixel.
|
498 |
+
# -1 is outside of the 0-1 range that the controlnet normalized
|
499 |
+
# input is in.
|
500 |
+
conditioning_image = conditioning_image * (conditioning_image_mask < 0.5) + -1.0 * (conditioning_image_mask >= 0.5)
|
501 |
+
|
502 |
+
return dict(conditioning_image=conditioning_image, conditioning_image_mask=conditioning_image_mask, conditioning_image_as_pil=conditioning_image_as_pil)
|
503 |
+
|
504 |
+
|
505 |
+
# TODO: would be nice to just call a function from a tokenizers https://github.com/huggingface/tokenizers
|
506 |
+
# i.e. afaik tokenizing shouldn't require holding any state
|
507 |
+
|
508 |
+
tokenizer_one = CLIPTokenizerFast.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer")
|
509 |
+
|
510 |
+
tokenizer_two = CLIPTokenizerFast.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2")
|
511 |
+
|
512 |
+
|
513 |
+
def sdxl_tokenize_one(prompts):
|
514 |
+
return tokenizer_one(
|
515 |
+
prompts,
|
516 |
+
padding="max_length",
|
517 |
+
max_length=tokenizer_one.model_max_length,
|
518 |
+
truncation=True,
|
519 |
+
return_tensors="pt",
|
520 |
+
).input_ids[0]
|
521 |
+
|
522 |
+
|
523 |
+
def sdxl_tokenize_two(prompts):
|
524 |
+
return tokenizer_two(
|
525 |
+
prompts,
|
526 |
+
padding="max_length",
|
527 |
+
max_length=tokenizer_one.model_max_length,
|
528 |
+
truncation=True,
|
529 |
+
return_tensors="pt",
|
530 |
+
).input_ids[0]
|
531 |
+
|
532 |
+
|
533 |
+
def sdxl_text_conditioning(text_encoder_one, text_encoder_two, text_input_ids_one, text_input_ids_two):
|
534 |
+
prompt_embeds_1 = text_encoder_one(
|
535 |
+
text_input_ids_one,
|
536 |
+
output_hidden_states=True,
|
537 |
+
).hidden_states[-2]
|
538 |
+
|
539 |
+
prompt_embeds_1 = prompt_embeds_1.view(prompt_embeds_1.shape[0], prompt_embeds_1.shape[1], -1)
|
540 |
+
|
541 |
+
prompt_embeds_2 = text_encoder_two(
|
542 |
+
text_input_ids_two,
|
543 |
+
output_hidden_states=True,
|
544 |
+
)
|
545 |
+
|
546 |
+
pooled_encoder_hidden_states = prompt_embeds_2[0]
|
547 |
+
|
548 |
+
prompt_embeds_2 = prompt_embeds_2.hidden_states[-2]
|
549 |
+
|
550 |
+
prompt_embeds_2 = prompt_embeds_2.view(prompt_embeds_2.shape[0], prompt_embeds_2.shape[1], -1)
|
551 |
+
|
552 |
+
encoder_hidden_states = torch.cat((prompt_embeds_1, prompt_embeds_2), dim=-1)
|
553 |
+
|
554 |
+
return encoder_hidden_states, pooled_encoder_hidden_states
|
555 |
+
|
556 |
+
|
557 |
+
def make_random_rectangle_mask(
|
558 |
+
height,
|
559 |
+
width,
|
560 |
+
margin=10,
|
561 |
+
bbox_min_size=100,
|
562 |
+
bbox_max_size=512,
|
563 |
+
min_times=1,
|
564 |
+
max_times=2,
|
565 |
+
):
|
566 |
+
mask = np.zeros((height, width), np.float32)
|
567 |
+
|
568 |
+
bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
|
569 |
+
|
570 |
+
times = np.random.randint(min_times, max_times + 1)
|
571 |
+
|
572 |
+
for i in range(times):
|
573 |
+
box_width = np.random.randint(bbox_min_size, bbox_max_size)
|
574 |
+
box_height = np.random.randint(bbox_min_size, bbox_max_size)
|
575 |
+
|
576 |
+
start_x = np.random.randint(margin, width - margin - box_width + 1)
|
577 |
+
start_y = np.random.randint(margin, height - margin - box_height + 1)
|
578 |
+
|
579 |
+
mask[start_y : start_y + box_height, start_x : start_x + box_width] = 1
|
580 |
+
|
581 |
+
return mask
|
582 |
+
|
583 |
+
|
584 |
+
def make_random_irregular_mask(height, width, max_angle=4, max_len=60, max_width=256, min_times=1, max_times=2):
|
585 |
+
import cv2
|
586 |
+
|
587 |
+
mask = np.zeros((height, width), np.float32)
|
588 |
+
|
589 |
+
times = np.random.randint(min_times, max_times + 1)
|
590 |
+
|
591 |
+
for i in range(times):
|
592 |
+
start_x = np.random.randint(width)
|
593 |
+
start_y = np.random.randint(height)
|
594 |
+
|
595 |
+
for j in range(1 + np.random.randint(5)):
|
596 |
+
angle = 0.01 + np.random.randint(max_angle)
|
597 |
+
|
598 |
+
if i % 2 == 0:
|
599 |
+
angle = 2 * 3.1415926 - angle
|
600 |
+
|
601 |
+
length = 10 + np.random.randint(max_len)
|
602 |
+
|
603 |
+
brush_w = 5 + np.random.randint(max_width)
|
604 |
+
|
605 |
+
end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
|
606 |
+
end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
|
607 |
+
|
608 |
+
choice = random.randint(0, 2)
|
609 |
+
|
610 |
+
if choice == 0:
|
611 |
+
cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
|
612 |
+
elif choice == 1:
|
613 |
+
cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1.0, thickness=-1)
|
614 |
+
elif choice == 2:
|
615 |
+
radius = brush_w // 2
|
616 |
+
mask[
|
617 |
+
start_y - radius : start_y + radius,
|
618 |
+
start_x - radius : start_x + radius,
|
619 |
+
] = 1
|
620 |
+
else:
|
621 |
+
assert False
|
622 |
+
|
623 |
+
start_x, start_y = end_x, end_y
|
624 |
+
|
625 |
+
return mask
|
626 |
+
|
627 |
+
|
628 |
+
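# Outpainting mask: independently mask each border strip (top, left, bottom, right) with the given
# probabilities, and force at least one strip if none was selected so the mask is never empty.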
def make_outpainting_mask(height, width, probs=[0.5, 0.5, 0.5, 0.5]):
|
629 |
+
mask = np.zeros((height, width), np.float32)
|
630 |
+
at_least_one_mask_applied = False
|
631 |
+
|
632 |
+
coords = [
|
633 |
+
[(0, 0), (1, get_padding(height))],
|
634 |
+
[(0, 0), (get_padding(width), 1)],
|
635 |
+
[(0, 1 - get_padding(height)), (1, 1)],
|
636 |
+
[(1 - get_padding(width), 0), (1, 1)],
|
637 |
+
]
|
638 |
+
|
639 |
+
for pp, coord in zip(probs, coords):
|
640 |
+
if np.random.random() < pp:
|
641 |
+
at_least_one_mask_applied = True
|
642 |
+
mask = apply_padding(mask=mask, coord=coord)
|
643 |
+
|
644 |
+
if not at_least_one_mask_applied:
|
645 |
+
idx = np.random.choice(range(len(coords)), p=np.array(probs) / sum(probs))
|
646 |
+
mask = apply_padding(mask=mask, coord=coords[idx])
|
647 |
+
|
648 |
+
return mask
|
649 |
+
|
650 |
+
|
651 |
+
def get_padding(size, min_padding_percent=0.04, max_padding_percent=0.5):
|
652 |
+
n1 = int(min_padding_percent * size)
|
653 |
+
n2 = int(max_padding_percent * size)
|
654 |
+
return np.random.randint(n1, n2) / size
|
655 |
+
|
656 |
+
|
657 |
+
def apply_padding(mask, coord):
|
658 |
+
height, width = mask.shape
|
659 |
+
|
660 |
+
mask[
|
661 |
+
int(coord[0][0] * height) : int(coord[1][0] * height),
|
662 |
+
int(coord[0][1] * width) : int(coord[1][1] * width),
|
663 |
+
] = 1
|
664 |
+
|
665 |
+
return mask
|
666 |
+
|
667 |
+
|
668 |
+
@torch.no_grad()
|
669 |
+
def sdxl_diffusion_loop(
|
670 |
+
prompts,
|
671 |
+
unet,
|
672 |
+
text_encoder_one,
|
673 |
+
text_encoder_two,
|
674 |
+
images=None,
|
675 |
+
controlnet=None,
|
676 |
+
adapter=None,
|
677 |
+
sigmas=None,
|
678 |
+
timesteps=None,
|
679 |
+
x_T=None,
|
680 |
+
micro_conditioning=None,
|
681 |
+
guidance_scale=5.0,
|
682 |
+
generator=None,
|
683 |
+
negative_prompts=None,
|
684 |
+
diffusion_loop=euler_ode_solver_diffusion_loop,
|
685 |
+
):
|
686 |
+
if negative_prompts is None:
|
687 |
+
negative_prompts = [""] * len(prompts)
|
688 |
+
|
689 |
+
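# encode the negative prompts in the same batch as the prompts (conditional first, negative second);
# sdxl_eps_theta splits the two halves for classifier-free guidance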
prompts += negative_prompts
|
690 |
+
|
691 |
+
encoder_hidden_states, pooled_encoder_hidden_states = sdxl_text_conditioning(
|
692 |
+
text_encoder_one,
|
693 |
+
text_encoder_two,
|
694 |
+
sdxl_tokenize_one(prompts).to(text_encoder_one.device),
|
695 |
+
sdxl_tokenize_two(prompts).to(text_encoder_two.device),
|
696 |
+
)
|
697 |
+
|
698 |
+
if sigmas is None:
|
699 |
+
sigmas = make_sigmas(device=unet.device)
|
700 |
+
|
701 |
+
if x_T is None:
|
702 |
+
x_T = torch.randn((1, 4, 1024 // 8, 1024 // 8), dtype=torch.float32, device=unet.device, generator=generator)
|
703 |
+
x_T = x_T * ((sigmas.max() ** 2 + 1) ** 0.5)
|
704 |
+
|
705 |
+
if timesteps is None:
|
706 |
+
timesteps = torch.linspace(0, sigmas.numel(), 50, dtype=torch.long, device=unet.device)
|
707 |
+
|
708 |
+
if micro_conditioning is None:
|
709 |
+
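# micro conditioning is (original height, original width, crop top, crop left, target height, target width)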
micro_conditioning = torch.tensor([1024, 1024, 0, 0, 1024, 1024], dtype=torch.long, device=unet.device)
|
710 |
+
|
711 |
+
if adapter is not None:
|
712 |
+
down_block_additional_residuals = adapter(images)
|
713 |
+
else:
|
714 |
+
down_block_additional_residuals = None
|
715 |
+
|
716 |
+
if controlnet is not None:
|
717 |
+
controlnet_cond = images
|
718 |
+
else:
|
719 |
+
controlnet_cond = None
|
720 |
+
|
721 |
+
eps_theta = lambda x_t, t, sigma: sdxl_eps_theta(
|
722 |
+
x_t=x_t,
|
723 |
+
t=t,
|
724 |
+
sigma=sigma,
|
725 |
+
unet=unet,
|
726 |
+
encoder_hidden_states=encoder_hidden_states,
|
727 |
+
pooled_encoder_hidden_states=pooled_encoder_hidden_states,
|
728 |
+
micro_conditioning=micro_conditioning,
|
729 |
+
guidance_scale=guidance_scale,
|
730 |
+
controlnet=controlnet,
|
731 |
+
controlnet_cond=controlnet_cond,
|
732 |
+
down_block_additional_residuals=down_block_additional_residuals,
|
733 |
+
)
|
734 |
+
|
735 |
+
x_0 = diffusion_loop(eps_theta=eps_theta, timesteps=timesteps, sigmas=sigmas, x_T=x_T)
|
736 |
+
|
737 |
+
return x_0
|
738 |
+
|
739 |
+
|
740 |
+
@torch.no_grad()
|
741 |
+
def sdxl_eps_theta(
|
742 |
+
x_t,
|
743 |
+
t,
|
744 |
+
sigma,
|
745 |
+
unet,
|
746 |
+
encoder_hidden_states,
|
747 |
+
pooled_encoder_hidden_states,
|
748 |
+
micro_conditioning,
|
749 |
+
guidance_scale,
|
750 |
+
controlnet=None,
|
751 |
+
controlnet_cond=None,
|
752 |
+
down_block_additional_residuals=None,
|
753 |
+
):
|
754 |
+
# TODO - how does this not effect the ode we are solving
|
755 |
+
scaled_x_t = x_t / ((sigma**2 + 1) ** 0.5)
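# only the copy fed to the network is scaled (the 1 / (sigma**2 + 1) ** 0.5 input preconditioning the unet was
# trained with); the solver still steps the unscaled x_t, so the ODE being integrated is unchanged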
|
756 |
+
|
757 |
+
if guidance_scale > 1.0:
|
758 |
+
scaled_x_t = torch.concat([scaled_x_t, scaled_x_t])
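# duplicate the latents so the conditional and negative-prompt predictions are computed in a single forward pass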
|
759 |
+
|
760 |
+
if controlnet is not None:
|
761 |
+
controlnet_out = controlnet(
|
762 |
+
x_t=scaled_x_t,
|
763 |
+
t=t,
|
764 |
+
encoder_hidden_states=encoder_hidden_states,
|
765 |
+
micro_conditioning=micro_conditioning,
|
766 |
+
pooled_encoder_hidden_states=pooled_encoder_hidden_states,
|
767 |
+
controlnet_cond=controlnet_cond,
|
768 |
+
)
|
769 |
+
|
770 |
+
down_block_additional_residuals = controlnet_out["down_block_res_samples"]
|
771 |
+
mid_block_additional_residual = controlnet_out["mid_block_res_sample"]
|
772 |
+
add_to_down_block_inputs = controlnet_out.get("add_to_down_block_inputs", None)
|
773 |
+
add_to_output = controlnet_out.get("add_to_output", None)
|
774 |
+
else:
|
775 |
+
mid_block_additional_residual = None
|
776 |
+
add_to_down_block_inputs = None
|
777 |
+
add_to_output = None
|
778 |
+
|
779 |
+
eps_hat = unet(
|
780 |
+
x_t=scaled_x_t,
|
781 |
+
t=t,
|
782 |
+
encoder_hidden_states=encoder_hidden_states,
|
783 |
+
micro_conditioning=micro_conditioning,
|
784 |
+
pooled_encoder_hidden_states=pooled_encoder_hidden_states,
|
785 |
+
down_block_additional_residuals=down_block_additional_residuals,
|
786 |
+
mid_block_additional_residual=mid_block_additional_residual,
|
787 |
+
add_to_down_block_inputs=add_to_down_block_inputs,
|
788 |
+
add_to_output=add_to_output,
|
789 |
+
)
|
790 |
+
|
791 |
+
if guidance_scale > 1.0:
|
792 |
+
eps_hat, eps_hat_uncond = eps_hat.chunk(2)  # prompts were batched [conditional, negative], so the first half is the conditional prediction
|
793 |
+
|
794 |
+
eps_hat = eps_hat_uncond + guidance_scale * (eps_hat - eps_hat_uncond)
|
795 |
+
|
796 |
+
return eps_hat
|
797 |
+
|
798 |
+
known_negative_prompt = "text, watermark, low-quality, signature, moiré pattern, downsampling, aliasing, distorted, blurry, glossy, blur, jpeg artifacts, compression artifacts, poorly drawn, low-resolution, bad, distortion, twisted, excessive, exaggerated pose, exaggerated limbs, grainy, symmetrical, duplicate, error, pattern, beginner, pixelated, fake, hyper, glitch, overexposed, high-contrast, bad-contrast"
|
799 |
+
|
800 |
+
def gen_sdxl_simplified_interface(
|
801 |
+
prompt: str,
|
802 |
+
negative_prompt: Optional[str] = None,
|
803 |
+
controlnet_checkpoint: Optional[str]=None,
|
804 |
+
controlnet: Optional[Literal["SDXLControlNet", "SDXLControlNetFull", "SDXLControlNetPreEncodedControlnetCond"]]=None,
|
805 |
+
adapter_checkpoint: Optional[str]=None,
|
806 |
+
num_inference_steps=50,
|
807 |
+
images=None,
|
808 |
+
masks=None,
|
809 |
+
apply_conditioning: Optional[Literal["canny"]]=None,
|
810 |
+
num_images: int=1,
|
811 |
+
device: Optional[str]=None,
|
812 |
+
text_encoder_one=None,
|
813 |
+
text_encoder_two=None,
|
814 |
+
unet=None,
|
815 |
+
vae=None,
|
816 |
+
):
|
817 |
+
if device is None:
|
818 |
+
if torch.cuda.is_available():
|
819 |
+
device = "cuda"
|
820 |
+
elif torch.backends.mps.is_available():
|
821 |
+
device = "mps"
|
822 |
+
|
823 |
+
if text_encoder_one is None:
|
824 |
+
text_encoder_one = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", variant="fp16", torch_dtype=torch.float16)
|
825 |
+
text_encoder_one.to(device=device)
|
826 |
+
|
827 |
+
if text_encoder_two is None:
|
828 |
+
text_encoder_two = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", variant="fp16", torch_dtype=torch.float16)
|
829 |
+
text_encoder_two.to(device=device)
|
830 |
+
|
831 |
+
if vae is None:
|
832 |
+
vae = SDXLVae.load_fp16_fix(device=device)
|
833 |
+
|
834 |
+
if unet is None:
|
835 |
+
unet = SDXLUNet.load_fp16(device=device)
|
836 |
+
|
837 |
+
if isinstance(controlnet, str) and controlnet_checkpoint is not None:
|
838 |
+
if controlnet == "SDXLControlNet":
|
839 |
+
controlnet = SDXLControlNet.load(controlnet_checkpoint, device=device, dtype=torch.float16)
|
840 |
+
elif controlnet == "SDXLControlNetFull":
|
841 |
+
controlnet = SDXLControlNetFull.load(controlnet_checkpoint, device=device, dtype=torch.float16)
|
842 |
+
elif controlnet == "SDXLControlNetPreEncodedControlnetCond":
|
843 |
+
controlnet = SDXLControlNetPreEncodedControlnetCond.load(controlnet_checkpoint, device=device, dtype=torch.float16)
|
844 |
+
else:
|
845 |
+
assert False
|
846 |
+
|
847 |
+
if adapter_checkpoint is not None:
|
848 |
+
adapter = SDXLAdapter.load(adapter_checkpoint, device=device, dtype=torch.float16)
|
849 |
+
else:
|
850 |
+
adapter = None
|
851 |
+
|
852 |
+
sigmas = make_sigmas(device=unet.device)
|
853 |
+
|
854 |
+
timesteps = torch.linspace(0, sigmas.numel(), num_inference_steps, dtype=torch.long, device=unet.device)
|
855 |
+
|
856 |
+
if images is not None:
|
857 |
+
if not isinstance(images, list):
|
858 |
+
images = [images]
|
859 |
+
|
860 |
+
if masks is not None and not isinstance(masks, list):
|
861 |
+
masks = [masks]
|
862 |
+
|
863 |
+
images_ = []
|
864 |
+
|
865 |
+
for image_idx, image in enumerate(images):
|
866 |
+
if isinstance(image, str):
|
867 |
+
image = Image.open(image)
|
868 |
+
image = image.convert("RGB")
|
869 |
+
image = image.resize((1024, 1024))
|
870 |
+
elif isinstance(image, Image.Image):
|
871 |
+
...
|
872 |
+
else:
|
873 |
+
assert False
|
874 |
+
|
875 |
+
if apply_conditioning == "canny":
|
876 |
+
import cv2
|
877 |
+
|
878 |
+
image = np.array(image)
|
879 |
+
image = cv2.Canny(image, 100, 200)
|
880 |
+
image = image[:, :, None]
|
881 |
+
image = np.concatenate([image, image, image], axis=2)
|
882 |
+
|
883 |
+
image = TF.to_tensor(image)
|
884 |
+
|
885 |
+
if masks is not None:
|
886 |
+
mask = masks[image_idx]
|
887 |
+
if isinstance(mask, str):
|
888 |
+
mask = Image.open(mask)
|
889 |
+
mask = mask.convert("L")
|
890 |
+
mask = mask.resize((1024, 1024))
|
891 |
+
elif isinstance(mask, Image.Image):
|
892 |
+
...
|
893 |
+
else:
|
894 |
+
assert False
|
895 |
+
mask = TF.to_tensor(mask)
|
896 |
+
|
897 |
+
if isinstance(controlnet, SDXLControlNetPreEncodedControlnetCond):
|
898 |
+
image = image * (mask < 0.5)
|
899 |
+
image = TF.normalize(image, [0.5], [0.5])
|
900 |
+
image = vae.encode(image)
|
901 |
+
mask = TF.resize(mask, (1024 // 8, 1024 // 8))
|
902 |
+
image = torch.concat((image, mask))
|
903 |
+
else:
|
904 |
+
image = image * (mask < 0.5) + -1.0 * (mask >= 0.5)
|
905 |
+
|
906 |
+
images_.append(image)
|
907 |
+
|
908 |
+
images_ = torch.concat(images_)
|
909 |
+
else:
|
910 |
+
images_ = None
|
911 |
+
|
912 |
+
x_0 = sdxl_diffusion_loop(
|
913 |
+
prompts=[prompt] * num_images,
|
914 |
+
negative_prompts=[negative_prompt] * num_images if negative_prompt is not None else None,
|
915 |
+
unet=unet,
|
916 |
+
text_encoder_one=text_encoder_one,
|
917 |
+
text_encoder_two=text_encoder_two,
|
918 |
+
sigmas=sigmas,
|
919 |
+
timesteps=timesteps,
|
920 |
+
controlnet=controlnet,
|
921 |
+
adapter=adapter,
|
922 |
+
images=images_,
|
923 |
+
)
|
924 |
+
|
925 |
+
x_0 = vae.decode(x_0)
|
926 |
+
x_0 = vae.output_tensor_to_pil(x_0)
|
927 |
+
|
928 |
+
return x_0
|
929 |
+
|
930 |
+
|
931 |
+
if __name__ == "__main__":
|
932 |
+
from argparse import ArgumentParser
|
933 |
+
|
934 |
+
args = ArgumentParser()
|
935 |
+
args.add_argument("--prompt", required=True, type=str)
|
936 |
+
args.add_argument("--num_images", required=False, type=int, default=1)
|
937 |
+
args.add_argument("--num_inference_steps", required=False, type=int, default=50)
|
938 |
+
args.add_argument("--image", required=False, type=str, default=None)
|
939 |
+
args.add_argument("--mask", required=False, type=str, default=None)
|
940 |
+
args.add_argument("--controlnet_checkpoint", required=False, type=str, default=None)
|
941 |
+
args.add_argument("--controlnet", required=False, choices=["SDXLControlNet", "SDXLControlNetFull", "SDXLControlNetPreEncodedControlnetCond"], default=None)
|
942 |
+
args.add_argument("--adapter_checkpoint", required=False, type=str, default=None)
|
943 |
+
args.add_argument("--apply_conditioning", choices=["canny"], required=False, default=None)
|
944 |
+
args.add_argument("--device", required=False, default=None)
|
945 |
+
args = args.parse_args()
|
946 |
+
|
947 |
+
images = gen_sdxl_simplified_interface(
|
948 |
+
prompt=args.prompt,
|
949 |
+
num_images=args.num_images,
|
950 |
+
num_inference_steps=args.num_inference_steps,
|
951 |
+
images=[args.image],
|
952 |
+
masks=[args.mask],
|
953 |
+
controlnet_checkpoint=args.controlnet_checkpoint,
|
954 |
+
controlnet=args.controlnet,
|
955 |
+
adapter_checkpoint=args.adapter_checkpoint,
|
956 |
+
apply_conditioning=args.apply_conditioning,
|
957 |
+
device=args.device,
|
958 |
+
negative_prompt=known_negative_prompt,
|
959 |
+
)
|
960 |
+
|
961 |
+
for i, image in enumerate(images):
|
962 |
+
image.save(f"out_{i}.png")
|
sdxl_models.py
ADDED
@@ -0,0 +1,1375 @@
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import safetensors.torch
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchvision.transforms.functional as TF
|
9 |
+
import xformers
|
10 |
+
from PIL import Image
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
class ModelUtils:
|
15 |
+
@property
|
16 |
+
def dtype(self):
|
17 |
+
return next(self.parameters()).dtype
|
18 |
+
|
19 |
+
@property
|
20 |
+
def device(self):
|
21 |
+
return next(self.parameters()).device
|
22 |
+
|
23 |
+
@classmethod
|
24 |
+
def load(cls, load_from: str, device, overrides: Optional[List[str]] = None):
|
25 |
+
import load_state_dict_patch
|
26 |
+
|
27 |
+
load_from = [load_from]
|
28 |
+
|
29 |
+
load_from += overrides or []
|
30 |
+
|
31 |
+
state_dict = {}
|
32 |
+
|
33 |
+
for load_from_ in load_from:
|
34 |
+
if os.path.isdir(load_from_):
|
35 |
+
load_from_ = os.path.join(load_from_, "diffusion_pytorch_model.safetensors")
|
36 |
+
|
37 |
+
state_dict.update(safetensors.torch.load_file(load_from_, device=device))
|
38 |
+
|
39 |
+
with torch.device("meta"):
|
40 |
+
model = cls()
|
41 |
+
|
42 |
+
model.load_state_dict(state_dict, assign=True)
|
43 |
+
|
44 |
+
return model
|
45 |
+
|
46 |
+
|
47 |
+
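# SDXL latent scaling factor: latents are multiplied by this after encoding and divided by it before decoding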
vae_scaling_factor = 0.13025
|
48 |
+
|
49 |
+
|
50 |
+
class SDXLVae(nn.Module, ModelUtils):
|
51 |
+
def __init__(self):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
# fmt: off
|
55 |
+
|
56 |
+
self.encoder = nn.ModuleDict(dict(
|
57 |
+
# 3 -> 128
|
58 |
+
conv_in=nn.Conv2d(3, 128, kernel_size=3, padding=1),
|
59 |
+
|
60 |
+
down_blocks=nn.ModuleList([
|
61 |
+
# 128 -> 128
|
62 |
+
nn.ModuleDict(dict(
|
63 |
+
resnets=nn.ModuleList([ResnetBlock2D(128, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6)]),
|
64 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)))]),
|
65 |
+
)),
|
66 |
+
# 128 -> 256
|
67 |
+
nn.ModuleDict(dict(
|
68 |
+
resnets=nn.ModuleList([ResnetBlock2D(128, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6)]),
|
69 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)))]),
|
70 |
+
)),
|
71 |
+
# 256 -> 512
|
72 |
+
nn.ModuleDict(dict(
|
73 |
+
resnets=nn.ModuleList([ResnetBlock2D(256, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
|
74 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1)))]),
|
75 |
+
)),
|
76 |
+
# 512 -> 512
|
77 |
+
nn.ModuleDict(dict(resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]))),
|
78 |
+
]),
|
79 |
+
|
80 |
+
# 512 -> 512
|
81 |
+
mid_block=nn.ModuleDict(dict(
|
82 |
+
attentions=nn.ModuleList([Attention(512, 512, qkv_bias=True)]),
|
83 |
+
resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
|
84 |
+
)),
|
85 |
+
|
86 |
+
# 512 -> 8
|
87 |
+
conv_norm_out=nn.GroupNorm(32, 512, eps=1e-06),
|
88 |
+
conv_act=nn.SiLU(),
|
89 |
+
conv_out=nn.Conv2d(512, 8, kernel_size=3, padding=1)
|
90 |
+
))
|
91 |
+
|
92 |
+
# 8 -> 8
|
93 |
+
self.quant_conv = nn.Conv2d(8, 8, kernel_size=1)
|
94 |
+
|
95 |
+
# 8 -> 4 from sampling mean and std
|
96 |
+
|
97 |
+
# 4 -> 4
|
98 |
+
self.post_quant_conv = nn.Conv2d(4, 4, 1)
|
99 |
+
|
100 |
+
self.decoder = nn.ModuleDict(dict(
|
101 |
+
# 4 -> 512
|
102 |
+
conv_in=nn.Conv2d(4, 512, kernel_size=3, padding=1),
|
103 |
+
|
104 |
+
# 512 -> 512
|
105 |
+
mid_block=nn.ModuleDict(dict(
|
106 |
+
attentions=nn.ModuleList([Attention(512, 512, qkv_bias=True)]),
|
107 |
+
resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
|
108 |
+
)),
|
109 |
+
|
110 |
+
up_blocks=nn.ModuleList([
|
111 |
+
# 512 -> 512
|
112 |
+
nn.ModuleDict(dict(
|
113 |
+
resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
|
114 |
+
upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)))]),
|
115 |
+
)),
|
116 |
+
|
117 |
+
# 512 -> 512
|
118 |
+
nn.ModuleDict(dict(
|
119 |
+
resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
|
120 |
+
upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)))]),
|
121 |
+
)),
|
122 |
+
|
123 |
+
# 512 -> 256
|
124 |
+
nn.ModuleDict(dict(
|
125 |
+
resnets=nn.ModuleList([ResnetBlock2D(512, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6)]),
|
126 |
+
upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)))]),
|
127 |
+
)),
|
128 |
+
|
129 |
+
# 256 -> 128
|
130 |
+
nn.ModuleDict(dict(
|
131 |
+
resnets=nn.ModuleList([ResnetBlock2D(256, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6)]),
|
132 |
+
)),
|
133 |
+
]),
|
134 |
+
|
135 |
+
# 128 -> 3
|
136 |
+
conv_norm_out=nn.GroupNorm(32, 128, eps=1e-06),
|
137 |
+
conv_act=nn.SiLU(),
|
138 |
+
conv_out=nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
139 |
+
))
|
140 |
+
|
141 |
+
# fmt: on
|
142 |
+
|
143 |
+
def encode(self, x, generator=None):
|
144 |
+
h = x
|
145 |
+
|
146 |
+
h = self.encoder["conv_in"](h)
|
147 |
+
|
148 |
+
for down_block in self.encoder["down_blocks"]:
|
149 |
+
for resnet in down_block["resnets"]:
|
150 |
+
h = resnet(h)
|
151 |
+
|
152 |
+
if "downsamplers" in down_block:
|
153 |
+
h = down_block["downsamplers"][0]["conv"](h)
|
154 |
+
|
155 |
+
h = self.encoder["mid_block"]["resnets"][0](h)
|
156 |
+
h = self.encoder["mid_block"]["attentions"][0](h)
|
157 |
+
h = self.encoder["mid_block"]["resnets"][1](h)
|
158 |
+
|
159 |
+
h = self.encoder["conv_norm_out"](h)
|
160 |
+
h = self.encoder["conv_act"](h)
|
161 |
+
h = self.encoder["conv_out"](h)
|
162 |
+
|
163 |
+
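# quant_conv maps the 8 encoder channels to the mean and log-variance of the diagonal Gaussian posterior;
# sample from it via the reparameterization trick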
mean, logvar = self.quant_conv(h).chunk(2, dim=1)
|
164 |
+
|
165 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
166 |
+
|
167 |
+
std = torch.exp(0.5 * logvar)
|
168 |
+
|
169 |
+
z = mean + torch.randn(mean.shape, device=mean.device, dtype=mean.dtype, generator=generator) * std
|
170 |
+
|
171 |
+
z = z * vae_scaling_factor
|
172 |
+
|
173 |
+
return z
|
174 |
+
|
175 |
+
def decode(self, z):
|
176 |
+
z = z / vae_scaling_factor
|
177 |
+
|
178 |
+
h = z
|
179 |
+
|
180 |
+
h = self.post_quant_conv(h)
|
181 |
+
|
182 |
+
h = self.decoder["mid_block"]["resnets"][0](h)
|
183 |
+
h = self.decoder["mid_block"]["attentions"][0](h)
|
184 |
+
h = self.decoder["mid_block"]["resnets"][1](h)
|
185 |
+
|
186 |
+
for up_block in self.decoder["up_blocks"]:
|
187 |
+
for resnet in up_block["resnets"]:
|
188 |
+
h = resnet(h)
|
189 |
+
|
190 |
+
if "upsamplers" in up_block:
|
191 |
+
h = up_block["upsamplers"][0]["conv"](h)
|
192 |
+
|
193 |
+
h = self.decoder["conv_norm_out"](h)
|
194 |
+
h = self.decoder["conv_act"](h)
|
195 |
+
h = self.decoder["conv_out"](h)
|
196 |
+
|
197 |
+
x_pred = h
|
198 |
+
|
199 |
+
return x_pred
|
200 |
+
|
201 |
+
@classmethod
|
202 |
+
def input_pil_to_tensor(self, x):
|
203 |
+
x = TF.to_tensor(x)
|
204 |
+
x = TF.normalize(x, [0.5], [0.5])
|
205 |
+
if x.ndim == 3:
|
206 |
+
x = x[None, :, :, :]
|
207 |
+
return x
|
208 |
+
|
209 |
+
@classmethod
|
210 |
+
def output_tensor_to_pil(self, x_pred):
|
211 |
+
x_pred = ((x_pred * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
212 |
+
|
213 |
+
x_pred = x_pred.cpu().numpy()
|
214 |
+
|
215 |
+
x_pred = [Image.fromarray(x) for x in x_pred]
|
216 |
+
|
217 |
+
return x_pred
|
218 |
+
|
219 |
+
@classmethod
|
220 |
+
def load_fp32(cls, device=None, overrides=None):
|
221 |
+
return cls.load("./weights/sdxl_vae.safetensors", device=device, overrides=overrides)
|
222 |
+
|
223 |
+
@classmethod
|
224 |
+
def load_fp16(cls, device=None, overrides=None):
|
225 |
+
return cls.load("./weights/sdxl_vae.fp16.safetensors", device=device, overrides=overrides)
|
226 |
+
|
227 |
+
@classmethod
|
228 |
+
def load_fp16_fix(cls, device=None, overrides=None):
|
229 |
+
return cls.load("./weights/sdxl_vae_fp16_fix.safetensors", device=device, overrides=overrides)
|
230 |
+
|
231 |
+
|
232 |
+
class SDXLUNet(nn.Module, ModelUtils):
|
233 |
+
def __init__(self):
|
234 |
+
super().__init__()
|
235 |
+
|
236 |
+
# fmt: off
|
237 |
+
|
238 |
+
encoder_hidden_states_dim = 2048
|
239 |
+
|
240 |
+
# timesteps embedding:
|
241 |
+
|
242 |
+
time_sinusoidal_embedding_dim = 320
|
243 |
+
time_embedding_dim = 1280
|
244 |
+
|
245 |
+
self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)
|
246 |
+
|
247 |
+
self.time_embedding = nn.ModuleDict(dict(
|
248 |
+
linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
|
249 |
+
act=nn.SiLU(),
|
250 |
+
linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
|
251 |
+
))
|
252 |
+
|
253 |
+
# image size and crop coordinates conditioning embedding (i.e. micro conditioning):
|
254 |
+
|
255 |
+
num_micro_conditioning_values = 6
|
256 |
+
micro_conditioning_embedding_dim = 256
|
257 |
+
additional_embedding_encoder_dim = 1280
|
258 |
+
self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)
|
259 |
+
|
260 |
+
self.add_embedding = nn.ModuleDict(dict(
|
261 |
+
linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
|
262 |
+
act=nn.SiLU(),
|
263 |
+
linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
|
264 |
+
))
|
265 |
+
|
266 |
+
# actual unet blocks:
|
267 |
+
|
268 |
+
self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
269 |
+
|
270 |
+
self.down_blocks = nn.ModuleList([
|
271 |
+
# 320 -> 320
|
272 |
+
nn.ModuleDict(dict(
|
273 |
+
resnets=nn.ModuleList([
|
274 |
+
ResnetBlock2D(320, 320, time_embedding_dim),
|
275 |
+
ResnetBlock2D(320, 320, time_embedding_dim),
|
276 |
+
]),
|
277 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
|
278 |
+
)),
|
279 |
+
# 320 -> 640
|
280 |
+
nn.ModuleDict(dict(
|
281 |
+
resnets=nn.ModuleList([
|
282 |
+
ResnetBlock2D(320, 640, time_embedding_dim),
|
283 |
+
ResnetBlock2D(640, 640, time_embedding_dim),
|
284 |
+
]),
|
285 |
+
attentions=nn.ModuleList([
|
286 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
287 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
288 |
+
]),
|
289 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
|
290 |
+
)),
|
291 |
+
# 640 -> 1280
|
292 |
+
nn.ModuleDict(dict(
|
293 |
+
resnets=nn.ModuleList([
|
294 |
+
ResnetBlock2D(640, 1280, time_embedding_dim),
|
295 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
296 |
+
]),
|
297 |
+
attentions=nn.ModuleList([
|
298 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
299 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
300 |
+
]),
|
301 |
+
)),
|
302 |
+
])
|
303 |
+
|
304 |
+
self.mid_block = nn.ModuleDict(dict(
|
305 |
+
resnets=nn.ModuleList([
|
306 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
307 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
308 |
+
]),
|
309 |
+
attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
|
310 |
+
))
|
311 |
+
|
312 |
+
self.up_blocks = nn.ModuleList([
|
313 |
+
# 1280 -> 1280
|
314 |
+
nn.ModuleDict(dict(
|
315 |
+
resnets=nn.ModuleList([
|
316 |
+
ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
|
317 |
+
ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
|
318 |
+
ResnetBlock2D(1280 + 640, 1280, time_embedding_dim),
|
319 |
+
]),
|
320 |
+
attentions=nn.ModuleList([
|
321 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
322 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
323 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
324 |
+
]),
|
325 |
+
upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(1280, 1280, kernel_size=3, padding=1)))]),
|
326 |
+
)),
|
327 |
+
# 1280 -> 640
|
328 |
+
nn.ModuleDict(dict(
|
329 |
+
resnets=nn.ModuleList([
|
330 |
+
ResnetBlock2D(1280 + 640, 640, time_embedding_dim),
|
331 |
+
ResnetBlock2D(640 + 640, 640, time_embedding_dim),
|
332 |
+
ResnetBlock2D(640 + 320, 640, time_embedding_dim),
|
333 |
+
]),
|
334 |
+
attentions=nn.ModuleList([
|
335 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
336 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
337 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
338 |
+
]),
|
339 |
+
upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, padding=1)))]),
|
340 |
+
)),
|
341 |
+
# 640 -> 320
|
342 |
+
nn.ModuleDict(dict(
|
343 |
+
resnets=nn.ModuleList([
|
344 |
+
ResnetBlock2D(640 + 320, 320, time_embedding_dim),
|
345 |
+
ResnetBlock2D(320 + 320, 320, time_embedding_dim),
|
346 |
+
ResnetBlock2D(320 + 320, 320, time_embedding_dim),
|
347 |
+
]),
|
348 |
+
))
|
349 |
+
])
|
350 |
+
|
351 |
+
self.conv_norm_out = nn.GroupNorm(32, 320)
|
352 |
+
self.conv_act = nn.SiLU()
|
353 |
+
self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1)
|
354 |
+
|
355 |
+
# fmt: on
|
356 |
+
|
357 |
+
def forward(
|
358 |
+
self,
|
359 |
+
x_t,
|
360 |
+
t,
|
361 |
+
encoder_hidden_states,
|
362 |
+
micro_conditioning,
|
363 |
+
pooled_encoder_hidden_states,
|
364 |
+
down_block_additional_residuals: Optional[List[torch.Tensor]] = None,
|
365 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
366 |
+
add_to_down_block_inputs: Optional[List[torch.Tensor]] = None,
|
367 |
+
add_to_output: Optional[torch.Tensor] = None,
|
368 |
+
):
|
369 |
+
hidden_state = x_t
|
370 |
+
|
371 |
+
t = self.get_sinusoidal_timestep_embedding(t)
|
372 |
+
t = t.to(dtype=hidden_state.dtype)
|
373 |
+
t = self.time_embedding["linear_1"](t)
|
374 |
+
t = self.time_embedding["act"](t)
|
375 |
+
t = self.time_embedding["linear_2"](t)
|
376 |
+
|
377 |
+
additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
|
378 |
+
additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
|
379 |
+
additional_conditioning = additional_conditioning.flatten(1)
|
380 |
+
additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
|
381 |
+
additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
|
382 |
+
additional_conditioning = self.add_embedding["act"](additional_conditioning)
|
383 |
+
additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)
|
384 |
+
|
385 |
+
t = t + additional_conditioning
|
386 |
+
|
387 |
+
hidden_state = self.conv_in(hidden_state)
|
388 |
+
|
389 |
+
residuals = [hidden_state]
|
390 |
+
|
391 |
+
for down_block in self.down_blocks:
|
392 |
+
for i, resnet in enumerate(down_block["resnets"]):
|
393 |
+
if add_to_down_block_inputs is not None:
|
394 |
+
hidden_state = hidden_state + add_to_down_block_inputs.pop(0)
|
395 |
+
|
396 |
+
hidden_state = resnet(hidden_state, t)
|
397 |
+
|
398 |
+
if "attentions" in down_block:
|
399 |
+
hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)
|
400 |
+
|
401 |
+
residuals.append(hidden_state)
|
402 |
+
|
403 |
+
if "downsamplers" in down_block:
|
404 |
+
if add_to_down_block_inputs is not None:
|
405 |
+
hidden_state = hidden_state + add_to_down_block_inputs.pop(0)
|
406 |
+
|
407 |
+
hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)
|
408 |
+
|
409 |
+
residuals.append(hidden_state)
|
410 |
+
|
411 |
+
hidden_state = self.mid_block["resnets"][0](hidden_state, t)
|
412 |
+
hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
|
413 |
+
hidden_state = self.mid_block["resnets"][1](hidden_state, t)
|
414 |
+
|
415 |
+
if mid_block_additional_residual is not None:
|
416 |
+
hidden_state = hidden_state + mid_block_additional_residual
|
417 |
+
|
418 |
+
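# pop the stored down-block activations in reverse order and concatenate them on the channel dim as skip connections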
for up_block in self.up_blocks:
|
419 |
+
for i, resnet in enumerate(up_block["resnets"]):
|
420 |
+
residual = residuals.pop()
|
421 |
+
|
422 |
+
if down_block_additional_residuals is not None:
|
423 |
+
residual = residual + down_block_additional_residuals.pop()
|
424 |
+
|
425 |
+
hidden_state = torch.concat([hidden_state, residual], dim=1)
|
426 |
+
|
427 |
+
hidden_state = resnet(hidden_state, t)
|
428 |
+
|
429 |
+
if "attentions" in up_block:
|
430 |
+
hidden_state = up_block["attentions"][i](hidden_state, encoder_hidden_states)
|
431 |
+
|
432 |
+
if "upsamplers" in up_block:
|
433 |
+
hidden_state = F.interpolate(hidden_state, scale_factor=2.0, mode="nearest")
|
434 |
+
hidden_state = up_block["upsamplers"][0]["conv"](hidden_state)
|
435 |
+
|
436 |
+
hidden_state = self.conv_norm_out(hidden_state)
|
437 |
+
hidden_state = self.conv_act(hidden_state)
|
438 |
+
hidden_state = self.conv_out(hidden_state)
|
439 |
+
|
440 |
+
if add_to_output is not None:
|
441 |
+
hidden_state = hidden_state + add_to_output
|
442 |
+
|
443 |
+
eps_hat = hidden_state
|
444 |
+
|
445 |
+
return eps_hat
|
446 |
+
|
447 |
+
@classmethod
|
448 |
+
def load_fp32(cls, device=None, overrides=None):
|
449 |
+
return cls.load("./weights/sdxl_unet.safetensors", device=device, overrides=overrides)
|
450 |
+
|
451 |
+
@classmethod
|
452 |
+
def load_fp16(cls, device=None, overrides=None):
|
453 |
+
return cls.load("./weights/sdxl_unet.fp16.safetensors", device=device, overrides=overrides)
|
454 |
+
|
455 |
+
|
456 |
+
class SDXLControlNet(nn.Module, ModelUtils):
|
457 |
+
def __init__(self):
|
458 |
+
super().__init__()
|
459 |
+
|
460 |
+
# fmt: off
|
461 |
+
|
462 |
+
encoder_hidden_states_dim = 2048
|
463 |
+
|
464 |
+
# timesteps embedding:
|
465 |
+
|
466 |
+
time_sinusoidal_embedding_dim = 320
|
467 |
+
time_embedding_dim = 1280
|
468 |
+
|
469 |
+
self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)
|
470 |
+
|
471 |
+
self.time_embedding = nn.ModuleDict(dict(
|
472 |
+
linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
|
473 |
+
act=nn.SiLU(),
|
474 |
+
linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
|
475 |
+
))
|
476 |
+
|
477 |
+
# image size and crop coordinates conditioning embedding (i.e. micro conditioning):
|
478 |
+
|
479 |
+
num_micro_conditioning_values = 6
|
480 |
+
micro_conditioning_embedding_dim = 256
|
481 |
+
additional_embedding_encoder_dim = 1280
|
482 |
+
self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)
|
483 |
+
|
484 |
+
self.add_embedding = nn.ModuleDict(dict(
|
485 |
+
linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
|
486 |
+
act=nn.SiLU(),
|
487 |
+
linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
|
488 |
+
))
|
489 |
+
|
490 |
+
# controlnet cond embedding:
|
491 |
+
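# small conv stack that downsamples the 3-channel conditioning image by 8x to the latent resolution;
# the output conv is zero-initialized so the conditioning contributes nothing at the start of training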
self.controlnet_cond_embedding = nn.ModuleDict(dict(
|
492 |
+
conv_in=nn.Conv2d(3, 16, kernel_size=3, padding=1),
|
493 |
+
blocks=nn.ModuleList([
|
494 |
+
# 16 -> 32
|
495 |
+
nn.Conv2d(16, 16, kernel_size=3, padding=1),
|
496 |
+
nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2),
|
497 |
+
# 32 -> 96
|
498 |
+
nn.Conv2d(32, 32, kernel_size=3, padding=1),
|
499 |
+
nn.Conv2d(32, 96, kernel_size=3, padding=1, stride=2),
|
500 |
+
# 96 -> 256
|
501 |
+
nn.Conv2d(96, 96, kernel_size=3, padding=1),
|
502 |
+
nn.Conv2d(96, 256, kernel_size=3, padding=1, stride=2),
|
503 |
+
]),
|
504 |
+
conv_out=zero_module(nn.Conv2d(256, 320, kernel_size=3, padding=1)),
|
505 |
+
))
|
506 |
+
|
507 |
+
# actual unet blocks:
|
508 |
+
|
509 |
+
self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
510 |
+
|
511 |
+
self.down_blocks = nn.ModuleList([
|
512 |
+
# 320 -> 320
|
513 |
+
nn.ModuleDict(dict(
|
514 |
+
resnets=nn.ModuleList([
|
515 |
+
ResnetBlock2D(320, 320, time_embedding_dim),
|
516 |
+
ResnetBlock2D(320, 320, time_embedding_dim),
|
517 |
+
]),
|
518 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
|
519 |
+
)),
|
520 |
+
# 320 -> 640
|
521 |
+
nn.ModuleDict(dict(
|
522 |
+
resnets=nn.ModuleList([
|
523 |
+
ResnetBlock2D(320, 640, time_embedding_dim),
|
524 |
+
ResnetBlock2D(640, 640, time_embedding_dim),
|
525 |
+
]),
|
526 |
+
attentions=nn.ModuleList([
|
527 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
528 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
529 |
+
]),
|
530 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
|
531 |
+
)),
|
532 |
+
# 640 -> 1280
|
533 |
+
nn.ModuleDict(dict(
|
534 |
+
resnets=nn.ModuleList([
|
535 |
+
ResnetBlock2D(640, 1280, time_embedding_dim),
|
536 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
537 |
+
]),
|
538 |
+
attentions=nn.ModuleList([
|
539 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
540 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
541 |
+
]),
|
542 |
+
)),
|
543 |
+
])
|
544 |
+
|
545 |
+
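# zero-initialized 1x1 projections: at initialization the controlnet adds nothing to the unet residuals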
self.controlnet_down_blocks = nn.ModuleList([
|
546 |
+
zero_module(nn.Conv2d(320, 320, kernel_size=1)),
|
547 |
+
zero_module(nn.Conv2d(320, 320, kernel_size=1)),
|
548 |
+
zero_module(nn.Conv2d(320, 320, kernel_size=1)),
|
549 |
+
zero_module(nn.Conv2d(320, 320, kernel_size=1)),
|
550 |
+
zero_module(nn.Conv2d(640, 640, kernel_size=1)),
|
551 |
+
zero_module(nn.Conv2d(640, 640, kernel_size=1)),
|
552 |
+
zero_module(nn.Conv2d(640, 640, kernel_size=1)),
|
553 |
+
zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
|
554 |
+
zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
|
555 |
+
])
|
556 |
+
|
557 |
+
self.mid_block = nn.ModuleDict(dict(
|
558 |
+
resnets=nn.ModuleList([
|
559 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
560 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
561 |
+
]),
|
562 |
+
attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
|
563 |
+
))
|
564 |
+
|
565 |
+
self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1))
|
566 |
+
|
567 |
+
# fmt: on
|
568 |
+
|
569 |
+
def forward(
|
570 |
+
self,
|
571 |
+
x_t,
|
572 |
+
t,
|
573 |
+
encoder_hidden_states,
|
574 |
+
micro_conditioning,
|
575 |
+
pooled_encoder_hidden_states,
|
576 |
+
controlnet_cond,
|
577 |
+
):
|
578 |
+
hidden_state = x_t
|
579 |
+
|
580 |
+
t = self.get_sinusoidal_timestep_embedding(t)
|
581 |
+
t = t.to(dtype=hidden_state.dtype)
|
582 |
+
t = self.time_embedding["linear_1"](t)
|
583 |
+
t = self.time_embedding["act"](t)
|
584 |
+
t = self.time_embedding["linear_2"](t)
|
585 |
+
|
586 |
+
additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
|
587 |
+
additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
|
588 |
+
additional_conditioning = additional_conditioning.flatten(1)
|
589 |
+
additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
|
590 |
+
additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
|
591 |
+
additional_conditioning = self.add_embedding["act"](additional_conditioning)
|
592 |
+
additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)
|
593 |
+
|
594 |
+
t = t + additional_conditioning
|
595 |
+
|
596 |
+
controlnet_cond = self.controlnet_cond_embedding["conv_in"](controlnet_cond)
|
597 |
+
controlnet_cond = F.silu(controlnet_cond)
|
598 |
+
|
599 |
+
for block in self.controlnet_cond_embedding["blocks"]:
|
600 |
+
controlnet_cond = F.silu(block(controlnet_cond))
|
601 |
+
|
602 |
+
controlnet_cond = self.controlnet_cond_embedding["conv_out"](controlnet_cond)
|
603 |
+
|
604 |
+
hidden_state = self.conv_in(hidden_state)
|
605 |
+
|
606 |
+
hidden_state = hidden_state + controlnet_cond
|
607 |
+
|
608 |
+
down_block_res_sample = self.controlnet_down_blocks[0](hidden_state)
|
609 |
+
down_block_res_samples = [down_block_res_sample]
|
610 |
+
|
611 |
+
for down_block in self.down_blocks:
|
612 |
+
for i, resnet in enumerate(down_block["resnets"]):
|
613 |
+
hidden_state = resnet(hidden_state, t)
|
614 |
+
|
615 |
+
if "attentions" in down_block:
|
616 |
+
hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)
|
617 |
+
|
618 |
+
down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
|
619 |
+
down_block_res_samples.append(down_block_res_sample)
|
620 |
+
|
621 |
+
if "downsamplers" in down_block:
|
622 |
+
hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)
|
623 |
+
|
624 |
+
down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
|
625 |
+
down_block_res_samples.append(down_block_res_sample)
|
626 |
+
|
627 |
+
hidden_state = self.mid_block["resnets"][0](hidden_state, t)
|
628 |
+
hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
|
629 |
+
hidden_state = self.mid_block["resnets"][1](hidden_state, t)
|
630 |
+
|
631 |
+
mid_block_res_sample = self.controlnet_mid_block(hidden_state)
|
632 |
+
|
633 |
+
return dict(
|
634 |
+
down_block_res_samples=down_block_res_samples,
|
635 |
+
mid_block_res_sample=mid_block_res_sample,
|
636 |
+
)
|
637 |
+
|
638 |
+
@classmethod
|
639 |
+
def from_unet(cls, unet):
|
640 |
+
controlnet = cls()
|
641 |
+
|
642 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
643 |
+
controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
|
644 |
+
|
645 |
+
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
646 |
+
|
647 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
648 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
649 |
+
|
650 |
+
return controlnet
|
651 |
+
|
652 |
+
|
653 |
+
class SDXLControlNetPreEncodedControlnetCond(nn.Module, ModelUtils):
|
654 |
+
def __init__(self):
|
655 |
+
super().__init__()
|
656 |
+
|
657 |
+
# fmt: off
|
658 |
+
|
659 |
+
encoder_hidden_states_dim = 2048
|
660 |
+
|
661 |
+
# timesteps embedding:
|
662 |
+
|
663 |
+
time_sinusoidal_embedding_dim = 320
|
664 |
+
time_embedding_dim = 1280
|
665 |
+
|
666 |
+
self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)
|
667 |
+
|
668 |
+
self.time_embedding = nn.ModuleDict(dict(
|
669 |
+
linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
|
670 |
+
act=nn.SiLU(),
|
671 |
+
linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
|
672 |
+
))
|
673 |
+
|
674 |
+
# image size and crop coordinates conditioning embedding (i.e. micro conditioning):
|
675 |
+
|
676 |
+
num_micro_conditioning_values = 6
|
677 |
+
micro_conditioning_embedding_dim = 256
|
678 |
+
additional_embedding_encoder_dim = 1280
|
679 |
+
self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)
|
680 |
+
|
681 |
+
self.add_embedding = nn.ModuleDict(dict(
|
682 |
+
linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
|
683 |
+
act=nn.SiLU(),
|
684 |
+
linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
|
685 |
+
))
|
686 |
+
|
687 |
+
# actual unet blocks:
|
688 |
+
|
689 |
+
# unet latents: 4 +
|
690 |
+
# control image latents: 4 +
|
691 |
+
# controlnet_mask: 1
|
692 |
+
# = 9 channels
|
693 |
+
self.conv_in = nn.Conv2d(9, 320, kernel_size=3, padding=1)
|
694 |
+
|
695 |
+
self.down_blocks = nn.ModuleList([
|
696 |
+
# 320 -> 320
|
697 |
+
nn.ModuleDict(dict(
|
698 |
+
resnets=nn.ModuleList([
|
699 |
+
ResnetBlock2D(320, 320, time_embedding_dim),
|
700 |
+
ResnetBlock2D(320, 320, time_embedding_dim),
|
701 |
+
]),
|
702 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
|
703 |
+
)),
|
704 |
+
# 320 -> 640
|
705 |
+
nn.ModuleDict(dict(
|
706 |
+
resnets=nn.ModuleList([
|
707 |
+
ResnetBlock2D(320, 640, time_embedding_dim),
|
708 |
+
ResnetBlock2D(640, 640, time_embedding_dim),
|
709 |
+
]),
|
710 |
+
attentions=nn.ModuleList([
|
711 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
712 |
+
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
|
713 |
+
]),
|
714 |
+
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
|
715 |
+
)),
|
716 |
+
# 640 -> 1280
|
717 |
+
nn.ModuleDict(dict(
|
718 |
+
resnets=nn.ModuleList([
|
719 |
+
ResnetBlock2D(640, 1280, time_embedding_dim),
|
720 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
721 |
+
]),
|
722 |
+
attentions=nn.ModuleList([
|
723 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
724 |
+
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
|
725 |
+
]),
|
726 |
+
)),
|
727 |
+
])
|
728 |
+
|
729 |
+
self.controlnet_down_blocks = nn.ModuleList([
|
730 |
+
zero_module(nn.Conv2d(320, 320, kernel_size=1)),
|
731 |
+
zero_module(nn.Conv2d(320, 320, kernel_size=1)),
|
732 |
+
zero_module(nn.Conv2d(320, 320, kernel_size=1)),
|
733 |
+
zero_module(nn.Conv2d(320, 320, kernel_size=1)),
|
734 |
+
zero_module(nn.Conv2d(640, 640, kernel_size=1)),
|
735 |
+
zero_module(nn.Conv2d(640, 640, kernel_size=1)),
|
736 |
+
zero_module(nn.Conv2d(640, 640, kernel_size=1)),
|
737 |
+
zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
|
738 |
+
zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
|
739 |
+
])
|
740 |
+
|
741 |
+
self.mid_block = nn.ModuleDict(dict(
|
742 |
+
resnets=nn.ModuleList([
|
743 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
744 |
+
ResnetBlock2D(1280, 1280, time_embedding_dim),
|
745 |
+
]),
|
746 |
+
attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
|
747 |
+
))
|
748 |
+
|
749 |
+
self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1))
|
750 |
+
|
751 |
+
# fmt: on
|
752 |
+
|
753 |
+
def forward(
|
754 |
+
self,
|
755 |
+
x_t,
|
756 |
+
t,
|
757 |
+
encoder_hidden_states,
|
758 |
+
micro_conditioning,
|
759 |
+
pooled_encoder_hidden_states,
|
760 |
+
controlnet_cond,
|
761 |
+
):
|
762 |
+
hidden_state = x_t
|
763 |
+
|
764 |
+
t = self.get_sinusoidal_timestep_embedding(t)
|
765 |
+
t = t.to(dtype=hidden_state.dtype)
|
766 |
+
t = self.time_embedding["linear_1"](t)
|
767 |
+
t = self.time_embedding["act"](t)
|
768 |
+
t = self.time_embedding["linear_2"](t)
|
769 |
+
|
770 |
+
additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
|
771 |
+
additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
|
772 |
+
additional_conditioning = additional_conditioning.flatten(1)
|
773 |
+
additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
|
774 |
+
additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
|
775 |
+
additional_conditioning = self.add_embedding["act"](additional_conditioning)
|
776 |
+
additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)
|
777 |
+
|
778 |
+
t = t + additional_conditioning
|
779 |
+
|
780 |
+
hidden_state = torch.concat((hidden_state, controlnet_cond), dim=1)
|
781 |
+
|
782 |
+
hidden_state = self.conv_in(hidden_state)
|
783 |
+
|
784 |
+
down_block_res_sample = self.controlnet_down_blocks[0](hidden_state)
|
785 |
+
down_block_res_samples = [down_block_res_sample]
|
786 |
+
|
787 |
+
for down_block in self.down_blocks:
|
788 |
+
for i, resnet in enumerate(down_block["resnets"]):
|
789 |
+
hidden_state = resnet(hidden_state, t)
|
790 |
+
|
791 |
+
if "attentions" in down_block:
|
792 |
+
hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)
|
793 |
+
|
794 |
+
down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
|
795 |
+
down_block_res_samples.append(down_block_res_sample)
|
796 |
+
|
797 |
+
if "downsamplers" in down_block:
|
798 |
+
hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)
|
799 |
+
|
800 |
+
down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
|
801 |
+
down_block_res_samples.append(down_block_res_sample)
|
802 |
+
|
803 |
+
hidden_state = self.mid_block["resnets"][0](hidden_state, t)
|
804 |
+
hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
|
805 |
+
hidden_state = self.mid_block["resnets"][1](hidden_state, t)
|
806 |
+
|
807 |
+
mid_block_res_sample = self.controlnet_mid_block(hidden_state)
|
808 |
+
|
809 |
+
return dict(
|
810 |
+
down_block_res_samples=down_block_res_samples,
|
811 |
+
mid_block_res_sample=mid_block_res_sample,
|
812 |
+
)
|
813 |
+
|
814 |
+
@classmethod
|
815 |
+
def from_unet(cls, unet):
|
816 |
+
controlnet = cls()
|
817 |
+
|
818 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
819 |
+
controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
|
820 |
+
|
821 |
+
conv_in_weight = unet.conv_in.state_dict()["weight"]
|
822 |
+
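# the unet conv_in takes 4 latent channels; zero-pad the weight with 5 extra input channels
# (4 pre-encoded control image latents + 1 downsampled mask) so the new inputs initially have no effect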
padding = torch.zeros((320, 5, 3, 3), device=conv_in_weight.device, dtype=conv_in_weight.dtype)
|
823 |
+
conv_in_weight = torch.concat((conv_in_weight, padding), dim=1)
|
824 |
+
|
825 |
+
conv_in_bias = unet.conv_in.state_dict()["bias"]
|
826 |
+
|
827 |
+
controlnet.conv_in.load_state_dict({"weight": conv_in_weight, "bias": conv_in_bias})
|
828 |
+
|
829 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
830 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
831 |
+
|
832 |
+
return controlnet
|
833 |
+
|
834 |
+
|
class SDXLControlNetFull(nn.Module, ModelUtils):
    def __init__(self):
        super().__init__()

        # fmt: off

        encoder_hidden_states_dim = 2048

        # timesteps embedding:

        time_sinusoidal_embedding_dim = 320
        time_embedding_dim = 1280

        self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)

        self.time_embedding = nn.ModuleDict(dict(
            linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
            act=nn.SiLU(),
            linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
        ))

        # image size and crop coordinates conditioning embedding (i.e. micro conditioning):

        num_micro_conditioning_values = 6
        micro_conditioning_embedding_dim = 256
        additional_embedding_encoder_dim = 1280
        self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)

        self.add_embedding = nn.ModuleDict(dict(
            linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
            act=nn.SiLU(),
            linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
        ))

        # controlnet cond embedding:
        self.controlnet_cond_embedding = nn.ModuleDict(dict(
            conv_in=nn.Conv2d(3, 16, kernel_size=3, padding=1),
            blocks=nn.ModuleList([
                # 16 -> 32
                nn.Conv2d(16, 16, kernel_size=3, padding=1),
                nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2),
                # 32 -> 96
                nn.Conv2d(32, 32, kernel_size=3, padding=1),
                nn.Conv2d(32, 96, kernel_size=3, padding=1, stride=2),
                # 96 -> 256
                nn.Conv2d(96, 96, kernel_size=3, padding=1),
                nn.Conv2d(96, 256, kernel_size=3, padding=1, stride=2),
            ]),
            conv_out=zero_module(nn.Conv2d(256, 320, kernel_size=3, padding=1)),
        ))

        # actual unet blocks:

        self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.down_blocks = nn.ModuleList([
            # 320 -> 320
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(320, 320, time_embedding_dim),
                    ResnetBlock2D(320, 320, time_embedding_dim),
                ]),
                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
            )),
            # 320 -> 640
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(320, 640, time_embedding_dim),
                    ResnetBlock2D(640, 640, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                ]),
                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
            )),
            # 640 -> 1280
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(640, 1280, time_embedding_dim),
                    ResnetBlock2D(1280, 1280, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                ]),
            )),
        ])

        self.controlnet_down_blocks = nn.ModuleList([
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
        ])

        self.mid_block = nn.ModuleDict(dict(
            resnets=nn.ModuleList([
                ResnetBlock2D(1280, 1280, time_embedding_dim),
                ResnetBlock2D(1280, 1280, time_embedding_dim),
            ]),
            attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
        ))

        self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1))

        self.up_blocks = nn.ModuleList([
            # 1280 -> 1280
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
                    ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
                    ResnetBlock2D(1280 + 640, 1280, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                ]),
                upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(1280, 1280, kernel_size=3, padding=1)))]),
            )),
            # 1280 -> 640
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(1280 + 640, 640, time_embedding_dim),
                    ResnetBlock2D(640 + 640, 640, time_embedding_dim),
                    ResnetBlock2D(640 + 320, 640, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                ]),
                upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, padding=1)))]),
            )),
            # 640 -> 320
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(640 + 320, 320, time_embedding_dim),
                    ResnetBlock2D(320 + 320, 320, time_embedding_dim),
                    ResnetBlock2D(320 + 320, 320, time_embedding_dim),
                ]),
            ))
        ])

        # take the output of transformer(resnet(hidden_states)) and project it to
        # the number of residual channels for the same block
        self.controlnet_up_blocks = nn.ModuleList([
            zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
            zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
            zero_module(nn.Conv2d(1280, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
        ])

        self.conv_norm_out = nn.GroupNorm(32, 320)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1)

        self.controlnet_conv_out = zero_module(nn.Conv2d(4, 4, kernel_size=1))

        # fmt: on

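    # Everything above mirrors the SDXL unet architecture (encoder, mid block, and
    # decoder), while the extra `controlnet_*` modules are 1x1 zero-initialized
    # projections, so a freshly constructed SDXLControlNetFull contributes nothing
    # until it is trained. Resolution sketch (illustrative, assuming a 1024x1024 RGB
    # conditioning image): controlnet_cond_embedding's three stride-2 convs bring the
    # cond from 1024 -> 512 -> 256 -> 128, matching the 128x128 latent that conv_in
    # consumes.
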
    def forward(
        self,
        x_t,
        t,
        encoder_hidden_states,
        micro_conditioning,
        pooled_encoder_hidden_states,
        controlnet_cond,
    ):
        hidden_state = x_t

        t = self.get_sinusoidal_timestep_embedding(t)
        t = t.to(dtype=hidden_state.dtype)
        t = self.time_embedding["linear_1"](t)
        t = self.time_embedding["act"](t)
        t = self.time_embedding["linear_2"](t)

        additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
        additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
        additional_conditioning = additional_conditioning.flatten(1)
        additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
        additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
        additional_conditioning = self.add_embedding["act"](additional_conditioning)
        additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)

        t = t + additional_conditioning

        controlnet_cond = self.controlnet_cond_embedding["conv_in"](controlnet_cond)
        controlnet_cond = F.silu(controlnet_cond)

        for block in self.controlnet_cond_embedding["blocks"]:
            controlnet_cond = F.silu(block(controlnet_cond))

        controlnet_cond = self.controlnet_cond_embedding["conv_out"](controlnet_cond)

        hidden_state = self.conv_in(hidden_state)

        hidden_state = hidden_state + controlnet_cond

        residuals = [hidden_state]

        add_to_down_block_input = self.controlnet_down_blocks[0](hidden_state)
        add_to_down_block_inputs = [add_to_down_block_input]

        for down_block in self.down_blocks:
            for i, resnet in enumerate(down_block["resnets"]):
                hidden_state = resnet(hidden_state, t)

                if "attentions" in down_block:
                    hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)

                if len(add_to_down_block_inputs) < len(self.controlnet_down_blocks):
                    add_to_down_block_input = self.controlnet_down_blocks[len(add_to_down_block_inputs)](hidden_state)
                    add_to_down_block_inputs.append(add_to_down_block_input)

                residuals.append(hidden_state)

            if "downsamplers" in down_block:
                hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)

                if len(add_to_down_block_inputs) < len(self.controlnet_down_blocks):
                    add_to_down_block_input = self.controlnet_down_blocks[len(add_to_down_block_inputs)](hidden_state)
                    add_to_down_block_inputs.append(add_to_down_block_input)

                residuals.append(hidden_state)

        hidden_state = self.mid_block["resnets"][0](hidden_state, t)
        hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
        hidden_state = self.mid_block["resnets"][1](hidden_state, t)

        mid_block_res_sample = self.controlnet_mid_block(hidden_state)

        down_block_res_samples = []

        for up_block in self.up_blocks:
            for i, resnet in enumerate(up_block["resnets"]):
                residual = residuals.pop()

                hidden_state = torch.concat([hidden_state, residual], dim=1)

                hidden_state = resnet(hidden_state, t)

                if "attentions" in up_block:
                    hidden_state = up_block["attentions"][i](hidden_state, encoder_hidden_states)

                down_block_res_sample = self.controlnet_up_blocks[len(down_block_res_samples)](hidden_state)
                down_block_res_samples.insert(0, down_block_res_sample)

            if "upsamplers" in up_block:
                hidden_state = F.interpolate(hidden_state, scale_factor=2.0, mode="nearest")
                hidden_state = up_block["upsamplers"][0]["conv"](hidden_state)

        hidden_state = self.conv_norm_out(hidden_state)
        hidden_state = self.conv_act(hidden_state)
        hidden_state = self.conv_out(hidden_state)

        add_to_output = self.controlnet_conv_out(hidden_state)

        return dict(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            add_to_down_block_inputs=add_to_down_block_inputs,
            add_to_output=add_to_output,
        )

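    # A plausible reading of the four outputs (how the unet consumes them is not
    # shown in this hunk, so treat this as a hedged sketch): `add_to_down_block_inputs`
    # are added to the unet's down block inputs as it runs, `down_block_res_samples`
    # and `mid_block_res_sample` are residuals for the unet's skip connections and
    # mid block output, and `add_to_output` is added to the unet's final 4-channel
    # prediction. Unlike a standard controlnet, this "full" variant also runs a
    # decoder so it can produce residuals for the up path as well.
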
    @classmethod
    def from_unet(cls, unet):
        controlnet = cls()

        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
        controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())

        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())

        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
        controlnet.up_blocks.load_state_dict(unet.up_blocks.state_dict())

        controlnet.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict())
        controlnet.conv_out.load_state_dict(unet.conv_out.state_dict())

        return controlnet

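
# Both controlnet variants expose the same `from_unet` entry point, so a side-by-side
# comparison can start from one shared base unet. A hypothetical sketch (the loader
# name follows how the base unet is loaded elsewhere in this repo; treat the rest as
# illustrative rather than a documented API):
#
#     unet = SDXLUNet.load_fp16(device="cuda")
#     controlnet_full = SDXLControlNetFull.from_unet(unet)
#     controlnet_pre_encoded = SDXLControlNetPreEncodedControlnetCond.from_unet(unet)
#
# SDXLControlNetFull copies the whole unet (encoder, mid block, decoder, output head),
# while the pre-encoded-cond variant above copies only the encoder and mid block and
# zero-pads conv_in for its extra conditioning channels.
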
class SDXLAdapter(nn.Module, ModelUtils):
    def __init__(self):
        super().__init__()

        # fmt: off

        self.adapter = nn.ModuleDict(dict(
            # 3 -> 768
            unshuffle=nn.PixelUnshuffle(16),

            # 768 -> 320
            conv_in=nn.Conv2d(768, 320, kernel_size=3, padding=1),

            body=nn.ModuleList([
                # 320 -> 320
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([
                        nn.ModuleDict(dict(block1=nn.Conv2d(320, 320, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(320, 320, kernel_size=1))),
                        nn.ModuleDict(dict(block1=nn.Conv2d(320, 320, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(320, 320, kernel_size=1))),
                    ])
                )),
                # 320 -> 640
                nn.ModuleDict(dict(
                    in_conv=nn.Conv2d(320, 640, kernel_size=1),
                    resnets=nn.ModuleList([
                        nn.ModuleDict(dict(block1=nn.Conv2d(640, 640, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(640, 640, kernel_size=1))),
                        nn.ModuleDict(dict(block1=nn.Conv2d(640, 640, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(640, 640, kernel_size=1))),
                    ])
                )),
                # 640 -> 1280
                nn.ModuleDict(dict(
                    downsample=nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                    in_conv=nn.Conv2d(640, 1280, kernel_size=1),
                    resnets=nn.ModuleList([
                        nn.ModuleDict(dict(block1=nn.Conv2d(1280, 1280, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(1280, 1280, kernel_size=1))),
                        nn.ModuleDict(dict(block1=nn.Conv2d(1280, 1280, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(1280, 1280, kernel_size=1))),
                    ])
                )),
                # 1280 -> 1280
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([
                        nn.ModuleDict(dict(block1=nn.Conv2d(1280, 1280, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(1280, 1280, kernel_size=1))),
                        nn.ModuleDict(dict(block1=nn.Conv2d(1280, 1280, kernel_size=3, padding=1), act=nn.ReLU(), block2=nn.Conv2d(1280, 1280, kernel_size=1))),
                    ])
                )),
            ])
        ))

        # fmt: on

    def forward(self, x):
        x = self.adapter["unshuffle"](x)
        x = self.adapter["conv_in"](x)

        features = []

        for block in self.adapter["body"]:
            if "downsample" in block:
                x = block["downsample"](x)

            if "in_conv" in block:
                x = block["in_conv"](x)

            for resnet in block["resnets"]:
                residual = x
                x = resnet["block1"](x)
                x = resnet["act"](x)
                x = resnet["block2"](x)
                x = residual + x

            features.append(x)

        return features

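
# Shape sketch for the adapter (illustrative; the 1024x1024 input size is only an
# assumption for the arithmetic):
#
#     adapter = SDXLAdapter()
#     features = adapter(torch.randn(1, 3, 1024, 1024))
#     # -> [(1, 320, 64, 64), (1, 640, 64, 64), (1, 1280, 32, 32), (1, 1280, 32, 32)]
#
# PixelUnshuffle(16) folds each 16x16 pixel patch into channels (3 * 16 * 16 = 768),
# and only the third body block halves the resolution again via AvgPool2d.
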
def get_sinusoidal_embedding(
    indices: torch.Tensor,
    embedding_dim: int,
):
    half_dim = embedding_dim // 2
    exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=indices.device)
    exponent = exponent / half_dim

    emb = torch.exp(exponent)
    emb = indices.unsqueeze(-1).float() * emb
    emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)

    return emb

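
# Worked example: for timesteps of shape (batch,) and embedding_dim=320 the result is
# (batch, 320) -- the first 160 columns are cos terms, the last 160 sin terms, with
# frequencies decaying geometrically from 1 down to roughly 1/10000. For the micro
# conditioning path, `micro_conditioning` has shape (batch, 6), giving (batch, 6, 256);
# the forward passes then flatten(1) it to (batch, 1536) and concatenate the
# (batch, 1280) pooled text embedding, matching add_embedding's 2816-dim linear_1.
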
class ResnetBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, time_embedding_dim=None, eps=1e-5):
        super().__init__()

        if time_embedding_dim is not None:
            self.time_emb_proj = nn.Linear(time_embedding_dim, out_channels)
        else:
            self.time_emb_proj = None

        self.norm1 = torch.nn.GroupNorm(32, in_channels, eps=eps)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.norm2 = nn.GroupNorm(32, out_channels, eps=eps)
        self.dropout = nn.Dropout(0.0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.nonlinearity = nn.SiLU()

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.conv_shortcut = None

    def forward(self, hidden_states, temb=None):
        residual = hidden_states

        if self.time_emb_proj is not None:
            assert temb is not None
            temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        hidden_states = hidden_states + residual

        return hidden_states

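
# Shape sketch (illustrative): ResnetBlock2D(320, 640, 1280) maps a (batch, 320, h, w)
# feature map plus a (batch, 1280) time embedding to (batch, 640, h, w); the projected
# time embedding is broadcast over the spatial dims, and the 1x1 conv_shortcut only
# exists when the channel count changes.
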
class TransformerDecoder2D(nn.Module):
    def __init__(self, channels, encoder_hidden_states_dim, num_transformer_blocks):
        super().__init__()

        self.norm = nn.GroupNorm(32, channels, eps=1e-06)
        self.proj_in = nn.Linear(channels, channels)

        self.transformer_blocks = nn.ModuleList([TransformerDecoderBlock(channels, encoder_hidden_states_dim) for _ in range(num_transformer_blocks)])

        self.proj_out = nn.Linear(channels, channels)

    def forward(self, hidden_states, encoder_hidden_states):
        batch_size, channels, height, width = hidden_states.shape

        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
        hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states)

        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2).contiguous()

        hidden_states = hidden_states + residual

        return hidden_states

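
# The 2D feature map is flattened to a sequence of height * width tokens so each
# TransformerDecoderBlock can apply unmasked self attention, cross attention over the
# text encoder hidden states, and a GEGLU feed forward per token; the result is then
# reshaped back to NCHW and added to the residual. Despite the name, there is no
# causal masking anywhere in this path.
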
class TransformerDecoderBlock(nn.Module):
    def __init__(self, channels, encoder_hidden_states_dim):
        super().__init__()

        self.norm1 = nn.LayerNorm(channels)
        self.attn1 = Attention(channels, channels)

        self.norm2 = nn.LayerNorm(channels)
        self.attn2 = Attention(channels, encoder_hidden_states_dim)

        self.norm3 = nn.LayerNorm(channels)
        self.ff = nn.ModuleDict(dict(net=nn.Sequential(GEGLU(channels, 4 * channels), nn.Dropout(0.0), nn.Linear(4 * channels, channels))))

    def forward(self, hidden_states, encoder_hidden_states):
        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states

        hidden_states = self.attn2(self.norm2(hidden_states), encoder_hidden_states) + hidden_states

        hidden_states = self.ff["net"](self.norm3(hidden_states)) + hidden_states

        return hidden_states

class Attention(nn.Module):
    def __init__(self, channels, encoder_hidden_states_dim, qkv_bias=False):
        super().__init__()
        self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
        self.to_k = nn.Linear(encoder_hidden_states_dim, channels, bias=qkv_bias)
        self.to_v = nn.Linear(encoder_hidden_states_dim, channels, bias=qkv_bias)
        self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))

    def forward(self, hidden_states, encoder_hidden_states=None):
        batch_size, q_seq_len, channels = hidden_states.shape
        head_dim = 64

        if encoder_hidden_states is not None:
            kv = encoder_hidden_states
        else:
            kv = hidden_states

        kv_seq_len = kv.shape[1]

        query = self.to_q(hidden_states)
        key = self.to_k(kv)
        value = self.to_v(kv)

        query = query.reshape(batch_size, q_seq_len, channels // head_dim, head_dim).contiguous()
        key = key.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()
        value = value.reshape(batch_size, kv_seq_len, channels // head_dim, head_dim).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value)

        hidden_states = hidden_states.to(query.dtype)
        hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()

        hidden_states = self.to_out(hidden_states)

        return hidden_states

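
# Heads are fixed at head_dim=64, so `channels` must be a multiple of 64 (320, 640,
# and 1280 all are, giving 5, 10, and 20 heads). xformers' memory_efficient_attention
# expects tensors laid out as (batch, seq_len, num_heads, head_dim). A possible
# fallback when xformers is unavailable (an illustrative alternative, not what this
# module ships) is PyTorch 2.x's built-in SDPA, which wants (batch, heads, seq, dim):
#
#     q, k, v = (x.transpose(1, 2) for x in (query, key, value))
#     out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
#
# Both compute the same unmasked, 1/sqrt(head_dim)-scaled attention.
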
class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * F.gelu(gate)

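
# GEGLU doubles the projection width and uses one half to gate the other:
# out = a * gelu(b), where a and b are the two halves of proj(x). In the feed forward
# above this expands channels -> 4 * channels before the final Linear projects back.
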
def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
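

# zero_module is what turns the `controlnet_*` projections and the cond embedding's
# conv_out into "zero convs": every parameter starts at exactly zero, so the residuals
# they produce are zero at initialization and training starts from the base model's
# unmodified behavior.

if __name__ == "__main__":
    # Small CPU smoke test for the standalone building blocks above (an illustrative
    # sketch only; shapes are arbitrary and this is not part of any training or
    # inference path).
    t_emb = get_sinusoidal_embedding(torch.tensor([10, 500]), 320)
    assert t_emb.shape == (2, 320)

    resnet = ResnetBlock2D(320, 640, time_embedding_dim=1280)
    assert resnet(torch.randn(1, 320, 8, 8), torch.randn(1, 1280)).shape == (1, 640, 8, 8)

    geglu = GEGLU(64, 128)
    assert geglu(torch.randn(1, 4, 64)).shape == (1, 4, 128)

    zeroed = zero_module(nn.Conv2d(4, 4, kernel_size=1))
    assert all(p.abs().sum() == 0 for p in zeroed.parameters())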