OzzyGT (HF staff) committed
Commit 58f8532
1 Parent(s): 29f18d8

first poc version

Files changed (6):
  1. .gitignore +3 -0
  2. README.md +2 -2
  3. app.py +148 -0
  4. controlnet_union.py +1090 -0
  5. pipeline_sdxl_recolor.py +665 -0
  6. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
.venv/
.vscode/
__pycache__/
README.md CHANGED
@@ -4,10 +4,10 @@ emoji: 🏢
 colorFrom: indigo
 colorTo: gray
 sdk: gradio
-sdk_version: 4.44.0
+sdk_version: 4.42.0
 app_file: app.py
 pinned: false
 license: apache-2.0
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+This is a space made as PoC for the guide [Recoloring photos with diffusers](https://huggingface.co/blog/OzzyGT/diffusers-recolor)
app.py ADDED
@@ -0,0 +1,148 @@
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, ControlNetModel, TCDScheduler
from gradio_imageslider import ImageSlider
from image_gen_aux import LineArtPreprocessor
from PIL import Image, ImageEnhance

from controlnet_union import ControlNetModel_Union
from pipeline_sdxl_recolor import StableDiffusionXLRecolorPipeline

lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")

controlnet = [
    ControlNetModel.from_pretrained(
        "OzzyGT/ControlNet-recolorXL", torch_dtype=torch.float16, variant="fp16"
    ),
    ControlNetModel_Union.from_pretrained(
        "OzzyGT/controlnet-union-promax-sdxl-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    ),
]

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")

pipe = StableDiffusionXLRecolorPipeline.from_pretrained(
    "recoilme/ColorfulXL-Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet,
    variant="fp16",
).to("cuda")

pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl_vit-h.safetensors",
    image_encoder_folder="models/image_encoder",
)

scale = {
    "up": {"block_0": [1.0, 0.0, 1.0]},
}
pipe.set_ip_adapter_scale(scale)
pipe.enable_model_cpu_offload()

prompt = "high quality color photo, sharp, detailed, 4k, colorized, remastered"
negative_prompt = "blurry, low resolution, bad quality, pixelated, black and white, b&w, grayscale, monochrome, sepia"

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt, negative_prompt, "cuda", True)


@spaces.GPU(duration=16)
def recolor_image(image):
    source_image = image["background"]

    lineart_image = lineart_preprocessor(source_image, resolution_scale=0.7)[0]

    for image in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=[source_image, lineart_image],
        ip_adapter_image=source_image,
        num_inference_steps=8,
        guidance_scale=2.0,
        controlnet_conditioning_scale=[1.0, 0.5],
        control_guidance_end=[1.0, 0.9],
    ):
        yield image, source_image

    image = image.convert("RGBA")
    source_image = source_image.convert("RGBA")

    enhancer = ImageEnhance.Color(image)
    image = enhancer.enhance(4.0)

    alpha = image.split()[3]
    alpha = alpha.point(lambda p: p * 0.20)
    image.putalpha(alpha)

    merged_image = Image.alpha_composite(source_image, image)

    yield merged_image, source_image


def clear_result():
    return gr.update(value=None)


css = """
.gradio-container {
    width: 1024px !important;
}
"""


title = """<h1 align="center">Diffusers Image Fill</h1>
<div align="center">Upload a grayscale image to colorize it.</div>
<div align="center">This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-recolor'>Recoloring photos with diffusers</a>.</div>
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(title)

    run_button = gr.Button("Generate")

    with gr.Row():
        input_image = gr.ImageEditor(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            eraser=False,
            brush=False,
            sources=["upload"],
            image_mode="RGB",
        )

        result = ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=recolor_image,
        inputs=[input_image],
        outputs=result,
    )


demo.launch(share=False)
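The post-processing at the end of `recolor_image` is what keeps the result faithful to the original photo: the generated image is over-saturated with `ImageEnhance.Color`, reduced to roughly 20% opacity, and alpha-composited over the source, so only the color is transferred while the source's detail and luminance stay intact. Below is a minimal standalone sketch of that same blend step; the file names `source.png` and `generated.png` are hypothetical stand-ins for a grayscale photo and a same-sized recolored output.

from PIL import Image, ImageEnhance

# Hypothetical inputs: a grayscale photo and a recolored version of the same size.
source = Image.open("source.png").convert("RGBA")
generated = Image.open("generated.png").convert("RGBA")

# Over-saturate the generated colors before blending, as in recolor_image.
generated = ImageEnhance.Color(generated).enhance(4.0)

# Keep only ~20% opacity so the source detail and luminance dominate.
alpha = generated.split()[3].point(lambda p: p * 0.20)
generated.putalpha(alpha)

# Composite the tinted generation over the original photo.
recolored = Image.alpha_composite(source, generated)
recolored.save("recolored.png")

In the app itself, the loop over `pipe(...)` streams intermediate previews to the `ImageSlider`, and this composite is yielded last as the finished result.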
controlnet_union.py ADDED
@@ -0,0 +1,1090 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections import OrderedDict
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.loaders import FromOriginalModelMixin
21
+ from diffusers.models.attention_processor import (
22
+ ADDED_KV_ATTENTION_PROCESSORS,
23
+ CROSS_ATTENTION_PROCESSORS,
24
+ AttentionProcessor,
25
+ AttnAddedKVProcessor,
26
+ AttnProcessor,
27
+ )
28
+ from diffusers.models.embeddings import (
29
+ TextImageProjection,
30
+ TextImageTimeEmbedding,
31
+ TextTimeEmbedding,
32
+ TimestepEmbedding,
33
+ Timesteps,
34
+ )
35
+ from diffusers.models.modeling_utils import ModelMixin
36
+ from diffusers.models.unets.unet_2d_blocks import (
37
+ CrossAttnDownBlock2D,
38
+ DownBlock2D,
39
+ UNetMidBlock2DCrossAttn,
40
+ get_down_block,
41
+ )
42
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
43
+ from diffusers.utils import BaseOutput, logging
44
+ from torch import nn
45
+ from torch.nn import functional as F
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ # Transformer Block
51
+ # Used to exchange info between different conditions and input image
52
+ # With reference to https://github.com/TencentARC/T2I-Adapter/blob/SD/ldm/modules/encoders/adapter.py#L147
53
+ class QuickGELU(nn.Module):
54
+ def forward(self, x: torch.Tensor):
55
+ return x * torch.sigmoid(1.702 * x)
56
+
57
+
58
+ class LayerNorm(nn.LayerNorm):
59
+ """Subclass torch's LayerNorm to handle fp16."""
60
+
61
+ def forward(self, x: torch.Tensor):
62
+ orig_type = x.dtype
63
+ ret = super().forward(x)
64
+ return ret.type(orig_type)
65
+
66
+
67
+ class ResidualAttentionBlock(nn.Module):
68
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
69
+ super().__init__()
70
+
71
+ self.attn = nn.MultiheadAttention(d_model, n_head)
72
+ self.ln_1 = LayerNorm(d_model)
73
+ self.mlp = nn.Sequential(
74
+ OrderedDict(
75
+ [
76
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
77
+ ("gelu", QuickGELU()),
78
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
79
+ ]
80
+ )
81
+ )
82
+ self.ln_2 = LayerNorm(d_model)
83
+ self.attn_mask = attn_mask
84
+
85
+ def attention(self, x: torch.Tensor):
86
+ self.attn_mask = (
87
+ self.attn_mask.to(dtype=x.dtype, device=x.device)
88
+ if self.attn_mask is not None
89
+ else None
90
+ )
91
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
92
+
93
+ def forward(self, x: torch.Tensor):
94
+ x = x + self.attention(self.ln_1(x))
95
+ x = x + self.mlp(self.ln_2(x))
96
+ return x
97
+
98
+
99
+ # -----------------------------------------------------------------------------------------------------
100
+
101
+
102
+ @dataclass
103
+ class ControlNetOutput(BaseOutput):
104
+ """
105
+ The output of [`ControlNetModel`].
106
+
107
+ Args:
108
+ down_block_res_samples (`tuple[torch.Tensor]`):
109
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
110
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
111
+ used to condition the original UNet's downsampling activations.
112
+ mid_down_block_re_sample (`torch.Tensor`):
113
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
114
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
115
+ Output can be used to condition the original UNet's middle block activation.
116
+ """
117
+
118
+ down_block_res_samples: Tuple[torch.Tensor]
119
+ mid_block_res_sample: torch.Tensor
120
+
121
+
122
+ class ControlNetConditioningEmbedding(nn.Module):
123
+ """
124
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
125
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
126
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
127
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
128
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
129
+ model) to encode image-space conditions ... into feature maps ..."
130
+ """
131
+
132
+ # original setting is (16, 32, 96, 256)
133
+ def __init__(
134
+ self,
135
+ conditioning_embedding_channels: int,
136
+ conditioning_channels: int = 3,
137
+ block_out_channels: Tuple[int] = (48, 96, 192, 384),
138
+ ):
139
+ super().__init__()
140
+
141
+ self.conv_in = nn.Conv2d(
142
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
143
+ )
144
+
145
+ self.blocks = nn.ModuleList([])
146
+
147
+ for i in range(len(block_out_channels) - 1):
148
+ channel_in = block_out_channels[i]
149
+ channel_out = block_out_channels[i + 1]
150
+ self.blocks.append(
151
+ nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
152
+ )
153
+ self.blocks.append(
154
+ nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)
155
+ )
156
+
157
+ self.conv_out = zero_module(
158
+ nn.Conv2d(
159
+ block_out_channels[-1],
160
+ conditioning_embedding_channels,
161
+ kernel_size=3,
162
+ padding=1,
163
+ )
164
+ )
165
+
166
+ def forward(self, conditioning):
167
+ embedding = self.conv_in(conditioning)
168
+ embedding = F.silu(embedding)
169
+
170
+ for block in self.blocks:
171
+ embedding = block(embedding)
172
+ embedding = F.silu(embedding)
173
+
174
+ embedding = self.conv_out(embedding)
175
+
176
+ return embedding
177
+
178
+
179
+ class ControlNetModel_Union(ModelMixin, ConfigMixin, FromOriginalModelMixin):
180
+ """
181
+ A ControlNet model.
182
+
183
+ Args:
184
+ in_channels (`int`, defaults to 4):
185
+ The number of channels in the input sample.
186
+ flip_sin_to_cos (`bool`, defaults to `True`):
187
+ Whether to flip the sin to cos in the time embedding.
188
+ freq_shift (`int`, defaults to 0):
189
+ The frequency shift to apply to the time embedding.
190
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
191
+ The tuple of downsample blocks to use.
192
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
193
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
194
+ The tuple of output channels for each block.
195
+ layers_per_block (`int`, defaults to 2):
196
+ The number of layers per block.
197
+ downsample_padding (`int`, defaults to 1):
198
+ The padding to use for the downsampling convolution.
199
+ mid_block_scale_factor (`float`, defaults to 1):
200
+ The scale factor to use for the mid block.
201
+ act_fn (`str`, defaults to "silu"):
202
+ The activation function to use.
203
+ norm_num_groups (`int`, *optional*, defaults to 32):
204
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
205
+ in post-processing.
206
+ norm_eps (`float`, defaults to 1e-5):
207
+ The epsilon to use for the normalization.
208
+ cross_attention_dim (`int`, defaults to 1280):
209
+ The dimension of the cross attention features.
210
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
211
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
212
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
213
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
214
+ encoder_hid_dim (`int`, *optional*, defaults to None):
215
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
216
+ dimension to `cross_attention_dim`.
217
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
218
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
219
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
220
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
221
+ The dimension of the attention heads.
222
+ use_linear_projection (`bool`, defaults to `False`):
223
+ class_embed_type (`str`, *optional*, defaults to `None`):
224
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
225
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
226
+ addition_embed_type (`str`, *optional*, defaults to `None`):
227
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
228
+ "text". "text" will use the `TextTimeEmbedding` layer.
229
+ num_class_embeds (`int`, *optional*, defaults to 0):
230
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
231
+ class conditioning with `class_embed_type` equal to `None`.
232
+ upcast_attention (`bool`, defaults to `False`):
233
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
234
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
235
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
236
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
237
+ `class_embed_type="projection"`.
238
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
239
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
240
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
241
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
242
+ global_pool_conditions (`bool`, defaults to `False`):
243
+ """
244
+
245
+ _supports_gradient_checkpointing = True
246
+
247
+ @register_to_config
248
+ def __init__(
249
+ self,
250
+ in_channels: int = 4,
251
+ conditioning_channels: int = 3,
252
+ flip_sin_to_cos: bool = True,
253
+ freq_shift: int = 0,
254
+ down_block_types: Tuple[str] = (
255
+ "CrossAttnDownBlock2D",
256
+ "CrossAttnDownBlock2D",
257
+ "CrossAttnDownBlock2D",
258
+ "DownBlock2D",
259
+ ),
260
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
261
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
262
+ layers_per_block: int = 2,
263
+ downsample_padding: int = 1,
264
+ mid_block_scale_factor: float = 1,
265
+ act_fn: str = "silu",
266
+ norm_num_groups: Optional[int] = 32,
267
+ norm_eps: float = 1e-5,
268
+ cross_attention_dim: int = 1280,
269
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
270
+ encoder_hid_dim: Optional[int] = None,
271
+ encoder_hid_dim_type: Optional[str] = None,
272
+ attention_head_dim: Union[int, Tuple[int]] = 8,
273
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
274
+ use_linear_projection: bool = False,
275
+ class_embed_type: Optional[str] = None,
276
+ addition_embed_type: Optional[str] = None,
277
+ addition_time_embed_dim: Optional[int] = None,
278
+ num_class_embeds: Optional[int] = None,
279
+ upcast_attention: bool = False,
280
+ resnet_time_scale_shift: str = "default",
281
+ projection_class_embeddings_input_dim: Optional[int] = None,
282
+ controlnet_conditioning_channel_order: str = "rgb",
283
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
284
+ global_pool_conditions: bool = False,
285
+ addition_embed_type_num_heads=64,
286
+ num_control_type=6,
287
+ ):
288
+ super().__init__()
289
+
290
+ # If `num_attention_heads` is not defined (which is the case for most models)
291
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
292
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
293
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
294
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
295
+ # which is why we correct for the naming here.
296
+ num_attention_heads = num_attention_heads or attention_head_dim
297
+
298
+ # Check inputs
299
+ if len(block_out_channels) != len(down_block_types):
300
+ raise ValueError(
301
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
302
+ )
303
+
304
+ if not isinstance(only_cross_attention, bool) and len(
305
+ only_cross_attention
306
+ ) != len(down_block_types):
307
+ raise ValueError(
308
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
309
+ )
310
+
311
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
312
+ down_block_types
313
+ ):
314
+ raise ValueError(
315
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
316
+ )
317
+
318
+ if isinstance(transformer_layers_per_block, int):
319
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
320
+ down_block_types
321
+ )
322
+
323
+ # input
324
+ conv_in_kernel = 3
325
+ conv_in_padding = (conv_in_kernel - 1) // 2
326
+ self.conv_in = nn.Conv2d(
327
+ in_channels,
328
+ block_out_channels[0],
329
+ kernel_size=conv_in_kernel,
330
+ padding=conv_in_padding,
331
+ )
332
+
333
+ # time
334
+ time_embed_dim = block_out_channels[0] * 4
335
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
336
+ timestep_input_dim = block_out_channels[0]
337
+ self.time_embedding = TimestepEmbedding(
338
+ timestep_input_dim,
339
+ time_embed_dim,
340
+ act_fn=act_fn,
341
+ )
342
+
343
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
344
+ encoder_hid_dim_type = "text_proj"
345
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
346
+ logger.info(
347
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
348
+ )
349
+
350
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
351
+ raise ValueError(
352
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
353
+ )
354
+
355
+ if encoder_hid_dim_type == "text_proj":
356
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
357
+ elif encoder_hid_dim_type == "text_image_proj":
358
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
359
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
360
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
361
+ self.encoder_hid_proj = TextImageProjection(
362
+ text_embed_dim=encoder_hid_dim,
363
+ image_embed_dim=cross_attention_dim,
364
+ cross_attention_dim=cross_attention_dim,
365
+ )
366
+
367
+ elif encoder_hid_dim_type is not None:
368
+ raise ValueError(
369
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
370
+ )
371
+ else:
372
+ self.encoder_hid_proj = None
373
+
374
+ # class embedding
375
+ if class_embed_type is None and num_class_embeds is not None:
376
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
377
+ elif class_embed_type == "timestep":
378
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
379
+ elif class_embed_type == "identity":
380
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
381
+ elif class_embed_type == "projection":
382
+ if projection_class_embeddings_input_dim is None:
383
+ raise ValueError(
384
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
385
+ )
386
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
387
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
388
+ # 2. it projects from an arbitrary input dimension.
389
+ #
390
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
391
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
392
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
393
+ self.class_embedding = TimestepEmbedding(
394
+ projection_class_embeddings_input_dim, time_embed_dim
395
+ )
396
+ else:
397
+ self.class_embedding = None
398
+
399
+ if addition_embed_type == "text":
400
+ if encoder_hid_dim is not None:
401
+ text_time_embedding_from_dim = encoder_hid_dim
402
+ else:
403
+ text_time_embedding_from_dim = cross_attention_dim
404
+
405
+ self.add_embedding = TextTimeEmbedding(
406
+ text_time_embedding_from_dim,
407
+ time_embed_dim,
408
+ num_heads=addition_embed_type_num_heads,
409
+ )
410
+ elif addition_embed_type == "text_image":
411
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
412
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
413
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
414
+ self.add_embedding = TextImageTimeEmbedding(
415
+ text_embed_dim=cross_attention_dim,
416
+ image_embed_dim=cross_attention_dim,
417
+ time_embed_dim=time_embed_dim,
418
+ )
419
+ elif addition_embed_type == "text_time":
420
+ self.add_time_proj = Timesteps(
421
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
422
+ )
423
+ self.add_embedding = TimestepEmbedding(
424
+ projection_class_embeddings_input_dim, time_embed_dim
425
+ )
426
+
427
+ elif addition_embed_type is not None:
428
+ raise ValueError(
429
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
430
+ )
431
+
432
+ # control net conditioning embedding
433
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
434
+ conditioning_embedding_channels=block_out_channels[0],
435
+ block_out_channels=conditioning_embedding_out_channels,
436
+ conditioning_channels=conditioning_channels,
437
+ )
438
+
439
+ # Copyright by Qi Xin(2024/07/06)
440
+ # Condition Transformer(fuse single/multi conditions with input image)
441
+ # The Condition Transformer augment the feature representation of conditions
442
+ # The overall design is somewhat like resnet. The output of Condition Transformer is used to predict a condition bias adding to the original condition feature.
443
+ # num_control_type = 6
444
+ num_trans_channel = 320
445
+ num_trans_head = 8
446
+ num_trans_layer = 1
447
+ num_proj_channel = 320
448
+ task_scale_factor = num_trans_channel**0.5
449
+
450
+ self.task_embedding = nn.Parameter(
451
+ task_scale_factor * torch.randn(num_control_type, num_trans_channel)
452
+ )
453
+ self.transformer_layes = nn.Sequential(
454
+ *[
455
+ ResidualAttentionBlock(num_trans_channel, num_trans_head)
456
+ for _ in range(num_trans_layer)
457
+ ]
458
+ )
459
+ self.spatial_ch_projs = zero_module(
460
+ nn.Linear(num_trans_channel, num_proj_channel)
461
+ )
462
+ # -----------------------------------------------------------------------------------------------------
463
+
464
+ # Copyright by Qi Xin(2024/07/06)
465
+ # Control Encoder to distinguish different control conditions
466
+ # A simple but effective module, consists of an embedding layer and a linear layer, to inject the control info to time embedding.
467
+ self.control_type_proj = Timesteps(
468
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
469
+ )
470
+ self.control_add_embedding = TimestepEmbedding(
471
+ addition_time_embed_dim * num_control_type, time_embed_dim
472
+ )
473
+ # -----------------------------------------------------------------------------------------------------
474
+
475
+ self.down_blocks = nn.ModuleList([])
476
+ self.controlnet_down_blocks = nn.ModuleList([])
477
+
478
+ if isinstance(only_cross_attention, bool):
479
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
480
+
481
+ if isinstance(attention_head_dim, int):
482
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
483
+
484
+ if isinstance(num_attention_heads, int):
485
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
486
+
487
+ # down
488
+ output_channel = block_out_channels[0]
489
+
490
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
491
+ controlnet_block = zero_module(controlnet_block)
492
+ self.controlnet_down_blocks.append(controlnet_block)
493
+
494
+ for i, down_block_type in enumerate(down_block_types):
495
+ input_channel = output_channel
496
+ output_channel = block_out_channels[i]
497
+ is_final_block = i == len(block_out_channels) - 1
498
+
499
+ down_block = get_down_block(
500
+ down_block_type,
501
+ num_layers=layers_per_block,
502
+ transformer_layers_per_block=transformer_layers_per_block[i],
503
+ in_channels=input_channel,
504
+ out_channels=output_channel,
505
+ temb_channels=time_embed_dim,
506
+ add_downsample=not is_final_block,
507
+ resnet_eps=norm_eps,
508
+ resnet_act_fn=act_fn,
509
+ resnet_groups=norm_num_groups,
510
+ cross_attention_dim=cross_attention_dim,
511
+ num_attention_heads=num_attention_heads[i],
512
+ attention_head_dim=attention_head_dim[i]
513
+ if attention_head_dim[i] is not None
514
+ else output_channel,
515
+ downsample_padding=downsample_padding,
516
+ use_linear_projection=use_linear_projection,
517
+ only_cross_attention=only_cross_attention[i],
518
+ upcast_attention=upcast_attention,
519
+ resnet_time_scale_shift=resnet_time_scale_shift,
520
+ )
521
+ self.down_blocks.append(down_block)
522
+
523
+ for _ in range(layers_per_block):
524
+ controlnet_block = nn.Conv2d(
525
+ output_channel, output_channel, kernel_size=1
526
+ )
527
+ controlnet_block = zero_module(controlnet_block)
528
+ self.controlnet_down_blocks.append(controlnet_block)
529
+
530
+ if not is_final_block:
531
+ controlnet_block = nn.Conv2d(
532
+ output_channel, output_channel, kernel_size=1
533
+ )
534
+ controlnet_block = zero_module(controlnet_block)
535
+ self.controlnet_down_blocks.append(controlnet_block)
536
+
537
+ # mid
538
+ mid_block_channel = block_out_channels[-1]
539
+
540
+ controlnet_block = nn.Conv2d(
541
+ mid_block_channel, mid_block_channel, kernel_size=1
542
+ )
543
+ controlnet_block = zero_module(controlnet_block)
544
+ self.controlnet_mid_block = controlnet_block
545
+
546
+ self.mid_block = UNetMidBlock2DCrossAttn(
547
+ transformer_layers_per_block=transformer_layers_per_block[-1],
548
+ in_channels=mid_block_channel,
549
+ temb_channels=time_embed_dim,
550
+ resnet_eps=norm_eps,
551
+ resnet_act_fn=act_fn,
552
+ output_scale_factor=mid_block_scale_factor,
553
+ resnet_time_scale_shift=resnet_time_scale_shift,
554
+ cross_attention_dim=cross_attention_dim,
555
+ num_attention_heads=num_attention_heads[-1],
556
+ resnet_groups=norm_num_groups,
557
+ use_linear_projection=use_linear_projection,
558
+ upcast_attention=upcast_attention,
559
+ )
560
+
561
+ @classmethod
562
+ def from_unet(
563
+ cls,
564
+ unet: UNet2DConditionModel,
565
+ controlnet_conditioning_channel_order: str = "rgb",
566
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
567
+ load_weights_from_unet: bool = True,
568
+ ):
569
+ r"""
570
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
571
+
572
+ Parameters:
573
+ unet (`UNet2DConditionModel`):
574
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
575
+ where applicable.
576
+ """
577
+ transformer_layers_per_block = (
578
+ unet.config.transformer_layers_per_block
579
+ if "transformer_layers_per_block" in unet.config
580
+ else 1
581
+ )
582
+ encoder_hid_dim = (
583
+ unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
584
+ )
585
+ encoder_hid_dim_type = (
586
+ unet.config.encoder_hid_dim_type
587
+ if "encoder_hid_dim_type" in unet.config
588
+ else None
589
+ )
590
+ addition_embed_type = (
591
+ unet.config.addition_embed_type
592
+ if "addition_embed_type" in unet.config
593
+ else None
594
+ )
595
+ addition_time_embed_dim = (
596
+ unet.config.addition_time_embed_dim
597
+ if "addition_time_embed_dim" in unet.config
598
+ else None
599
+ )
600
+
601
+ controlnet = cls(
602
+ encoder_hid_dim=encoder_hid_dim,
603
+ encoder_hid_dim_type=encoder_hid_dim_type,
604
+ addition_embed_type=addition_embed_type,
605
+ addition_time_embed_dim=addition_time_embed_dim,
606
+ transformer_layers_per_block=transformer_layers_per_block,
607
+ # transformer_layers_per_block=[1, 2, 5],
608
+ in_channels=unet.config.in_channels,
609
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
610
+ freq_shift=unet.config.freq_shift,
611
+ down_block_types=unet.config.down_block_types,
612
+ only_cross_attention=unet.config.only_cross_attention,
613
+ block_out_channels=unet.config.block_out_channels,
614
+ layers_per_block=unet.config.layers_per_block,
615
+ downsample_padding=unet.config.downsample_padding,
616
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
617
+ act_fn=unet.config.act_fn,
618
+ norm_num_groups=unet.config.norm_num_groups,
619
+ norm_eps=unet.config.norm_eps,
620
+ cross_attention_dim=unet.config.cross_attention_dim,
621
+ attention_head_dim=unet.config.attention_head_dim,
622
+ num_attention_heads=unet.config.num_attention_heads,
623
+ use_linear_projection=unet.config.use_linear_projection,
624
+ class_embed_type=unet.config.class_embed_type,
625
+ num_class_embeds=unet.config.num_class_embeds,
626
+ upcast_attention=unet.config.upcast_attention,
627
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
628
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
629
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
630
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
631
+ )
632
+
633
+ if load_weights_from_unet:
634
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
635
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
636
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
637
+
638
+ if controlnet.class_embedding:
639
+ controlnet.class_embedding.load_state_dict(
640
+ unet.class_embedding.state_dict()
641
+ )
642
+
643
+ controlnet.down_blocks.load_state_dict(
644
+ unet.down_blocks.state_dict(), strict=False
645
+ )
646
+ controlnet.mid_block.load_state_dict(
647
+ unet.mid_block.state_dict(), strict=False
648
+ )
649
+
650
+ return controlnet
651
+
652
+ @property
653
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
654
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
655
+ r"""
656
+ Returns:
657
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
658
+ indexed by its weight name.
659
+ """
660
+ # set recursively
661
+ processors = {}
662
+
663
+ def fn_recursive_add_processors(
664
+ name: str,
665
+ module: torch.nn.Module,
666
+ processors: Dict[str, AttentionProcessor],
667
+ ):
668
+ if hasattr(module, "get_processor"):
669
+ processors[f"{name}.processor"] = module.get_processor(
670
+ return_deprecated_lora=True
671
+ )
672
+
673
+ for sub_name, child in module.named_children():
674
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
675
+
676
+ return processors
677
+
678
+ for name, module in self.named_children():
679
+ fn_recursive_add_processors(name, module, processors)
680
+
681
+ return processors
682
+
683
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
684
+ def set_attn_processor(
685
+ self,
686
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
687
+ _remove_lora=False,
688
+ ):
689
+ r"""
690
+ Sets the attention processor to use to compute attention.
691
+
692
+ Parameters:
693
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
694
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
695
+ for **all** `Attention` layers.
696
+
697
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
698
+ processor. This is strongly recommended when setting trainable attention processors.
699
+
700
+ """
701
+ count = len(self.attn_processors.keys())
702
+
703
+ if isinstance(processor, dict) and len(processor) != count:
704
+ raise ValueError(
705
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
706
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
707
+ )
708
+
709
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
710
+ if hasattr(module, "set_processor"):
711
+ if not isinstance(processor, dict):
712
+ module.set_processor(processor, _remove_lora=_remove_lora)
713
+ else:
714
+ module.set_processor(
715
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
716
+ )
717
+
718
+ for sub_name, child in module.named_children():
719
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
720
+
721
+ for name, module in self.named_children():
722
+ fn_recursive_attn_processor(name, module, processor)
723
+
724
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
725
+ def set_default_attn_processor(self):
726
+ """
727
+ Disables custom attention processors and sets the default attention implementation.
728
+ """
729
+ if all(
730
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
731
+ for proc in self.attn_processors.values()
732
+ ):
733
+ processor = AttnAddedKVProcessor()
734
+ elif all(
735
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
736
+ for proc in self.attn_processors.values()
737
+ ):
738
+ processor = AttnProcessor()
739
+ else:
740
+ raise ValueError(
741
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
742
+ )
743
+
744
+ self.set_attn_processor(processor, _remove_lora=True)
745
+
746
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
747
+ def set_attention_slice(self, slice_size):
748
+ r"""
749
+ Enable sliced attention computation.
750
+
751
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
752
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
753
+
754
+ Args:
755
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
756
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
757
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
758
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
759
+ must be a multiple of `slice_size`.
760
+ """
761
+ sliceable_head_dims = []
762
+
763
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
764
+ if hasattr(module, "set_attention_slice"):
765
+ sliceable_head_dims.append(module.sliceable_head_dim)
766
+
767
+ for child in module.children():
768
+ fn_recursive_retrieve_sliceable_dims(child)
769
+
770
+ # retrieve number of attention layers
771
+ for module in self.children():
772
+ fn_recursive_retrieve_sliceable_dims(module)
773
+
774
+ num_sliceable_layers = len(sliceable_head_dims)
775
+
776
+ if slice_size == "auto":
777
+ # half the attention head size is usually a good trade-off between
778
+ # speed and memory
779
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
780
+ elif slice_size == "max":
781
+ # make smallest slice possible
782
+ slice_size = num_sliceable_layers * [1]
783
+
784
+ slice_size = (
785
+ num_sliceable_layers * [slice_size]
786
+ if not isinstance(slice_size, list)
787
+ else slice_size
788
+ )
789
+
790
+ if len(slice_size) != len(sliceable_head_dims):
791
+ raise ValueError(
792
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
793
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
794
+ )
795
+
796
+ for i in range(len(slice_size)):
797
+ size = slice_size[i]
798
+ dim = sliceable_head_dims[i]
799
+ if size is not None and size > dim:
800
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
801
+
802
+ # Recursively walk through all the children.
803
+ # Any children which exposes the set_attention_slice method
804
+ # gets the message
805
+ def fn_recursive_set_attention_slice(
806
+ module: torch.nn.Module, slice_size: List[int]
807
+ ):
808
+ if hasattr(module, "set_attention_slice"):
809
+ module.set_attention_slice(slice_size.pop())
810
+
811
+ for child in module.children():
812
+ fn_recursive_set_attention_slice(child, slice_size)
813
+
814
+ reversed_slice_size = list(reversed(slice_size))
815
+ for module in self.children():
816
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
817
+
818
+ def _set_gradient_checkpointing(self, module, value=False):
819
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
820
+ module.gradient_checkpointing = value
821
+
822
+ def forward(
823
+ self,
824
+ sample: torch.FloatTensor,
825
+ timestep: Union[torch.Tensor, float, int],
826
+ encoder_hidden_states: torch.Tensor,
827
+ controlnet_cond: torch.FloatTensor,
828
+ conditioning_scale: float = 1.0,
829
+ class_labels: Optional[torch.Tensor] = None,
830
+ timestep_cond: Optional[torch.Tensor] = None,
831
+ attention_mask: Optional[torch.Tensor] = None,
832
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
833
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
834
+ guess_mode: bool = False,
835
+ return_dict: bool = True,
836
+ ) -> Union[ControlNetOutput, Tuple]:
837
+ """
838
+ The [`ControlNetModel`] forward method.
839
+
840
+ Args:
841
+ sample (`torch.FloatTensor`):
842
+ The noisy input tensor.
843
+ timestep (`Union[torch.Tensor, float, int]`):
844
+ The number of timesteps to denoise an input.
845
+ encoder_hidden_states (`torch.Tensor`):
846
+ The encoder hidden states.
847
+ controlnet_cond (`torch.FloatTensor`):
848
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
849
+ conditioning_scale (`float`, defaults to `1.0`):
850
+ The scale factor for ControlNet outputs.
851
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
852
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
853
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
854
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
855
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
856
+ embeddings.
857
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
858
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
859
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
860
+ negative values to the attention scores corresponding to "discard" tokens.
861
+ added_cond_kwargs (`dict`):
862
+ Additional conditions for the Stable Diffusion XL UNet.
863
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
864
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
865
+ guess_mode (`bool`, defaults to `False`):
866
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
867
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
868
+ return_dict (`bool`, defaults to `True`):
869
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
870
+
871
+ Returns:
872
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
873
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
874
+ returned where the first element is the sample tensor.
875
+ """
876
+ # check channel order
877
+ channel_order = self.config.controlnet_conditioning_channel_order
878
+
879
+ if channel_order == "rgb":
880
+ # in rgb order by default
881
+ ...
882
+ # elif channel_order == "bgr":
883
+ # controlnet_cond = torch.flip(controlnet_cond, dims=[1])
884
+ else:
885
+ raise ValueError(
886
+ f"unknown `controlnet_conditioning_channel_order`: {channel_order}"
887
+ )
888
+
889
+ # prepare attention_mask
890
+ if attention_mask is not None:
891
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
892
+ attention_mask = attention_mask.unsqueeze(1)
893
+
894
+ # 1. time
895
+ timesteps = timestep
896
+ if not torch.is_tensor(timesteps):
897
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
898
+ # This would be a good case for the `match` statement (Python 3.10+)
899
+ is_mps = sample.device.type == "mps"
900
+ if isinstance(timestep, float):
901
+ dtype = torch.float32 if is_mps else torch.float64
902
+ else:
903
+ dtype = torch.int32 if is_mps else torch.int64
904
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
905
+ elif len(timesteps.shape) == 0:
906
+ timesteps = timesteps[None].to(sample.device)
907
+
908
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
909
+ timesteps = timesteps.expand(sample.shape[0])
910
+
911
+ t_emb = self.time_proj(timesteps)
912
+
913
+ # timesteps does not contain any weights and will always return f32 tensors
914
+ # but time_embedding might actually be running in fp16. so we need to cast here.
915
+ # there might be better ways to encapsulate this.
916
+ t_emb = t_emb.to(dtype=sample.dtype)
917
+
918
+ emb = self.time_embedding(t_emb, timestep_cond)
919
+ aug_emb = None
920
+
921
+ if self.class_embedding is not None:
922
+ if class_labels is None:
923
+ raise ValueError(
924
+ "class_labels should be provided when num_class_embeds > 0"
925
+ )
926
+
927
+ if self.config.class_embed_type == "timestep":
928
+ class_labels = self.time_proj(class_labels)
929
+
930
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
931
+ emb = emb + class_emb
932
+
933
+ if self.config.addition_embed_type is not None:
934
+ if self.config.addition_embed_type == "text":
935
+ aug_emb = self.add_embedding(encoder_hidden_states)
936
+
937
+ elif self.config.addition_embed_type == "text_time":
938
+ if "text_embeds" not in added_cond_kwargs:
939
+ raise ValueError(
940
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
941
+ )
942
+ text_embeds = added_cond_kwargs.get("text_embeds")
943
+ if "time_ids" not in added_cond_kwargs:
944
+ raise ValueError(
945
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
946
+ )
947
+ time_ids = added_cond_kwargs.get("time_ids")
948
+ time_embeds = self.add_time_proj(time_ids.flatten())
949
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
950
+
951
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
952
+ add_embeds = add_embeds.to(emb.dtype)
953
+ aug_emb = self.add_embedding(add_embeds)
954
+
955
+ # Copyright by Qi Xin(2024/07/06)
956
+ # inject control type info to time embedding to distinguish different control conditions
957
+ control_type = (
958
+ torch.Tensor([0, 0, 0, 1, 0, 0, 0, 0])
959
+ .to(emb.device, dtype=emb.dtype)
960
+ .repeat(controlnet_cond.shape[0], 1)
961
+ )
962
+ control_embeds = self.control_type_proj(control_type.flatten())
963
+ control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
964
+ control_embeds = control_embeds.to(emb.dtype)
965
+ control_emb = self.control_add_embedding(control_embeds)
966
+ emb = emb + control_emb
967
+ # ---------------------------------------------------------------------------------
968
+
969
+ emb = emb + aug_emb if aug_emb is not None else emb
970
+
971
+ # 2. pre-process
972
+ sample = self.conv_in(sample)
973
+ indices = torch.nonzero(control_type[0])
974
+
975
+ # Copyright by Qi Xin(2024/07/06)
976
+ # add single/multi conditons to input image.
977
+ # Condition Transformer provides an easy and effective way to fuse different features naturally
978
+ inputs = []
979
+ condition_list = []
980
+ controlnet_cond_list = [0, 0, 0, controlnet_cond, 0, 0, 0, 0]
981
+
982
+ for idx in range(indices.shape[0] + 1):
983
+ if idx == indices.shape[0]:
984
+ single_controlnet_cond = sample
985
+ feat_seq = torch.mean(single_controlnet_cond, dim=(2, 3)) # N * C
986
+ else:
987
+ single_controlnet_cond = self.controlnet_cond_embedding(
988
+ controlnet_cond_list[indices[idx][0]]
989
+ )
990
+ feat_seq = torch.mean(single_controlnet_cond, dim=(2, 3)) # N * C
991
+ feat_seq = feat_seq + self.task_embedding[indices[idx][0]]
992
+
993
+ inputs.append(feat_seq.unsqueeze(1))
994
+ condition_list.append(single_controlnet_cond)
995
+
996
+ x = torch.cat(inputs, dim=1) # NxLxC
997
+ x = self.transformer_layes(x)
998
+
999
+ controlnet_cond_fuser = sample * 0.0
1000
+ for idx in range(indices.shape[0]):
1001
+ alpha = self.spatial_ch_projs(x[:, idx])
1002
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
1003
+ controlnet_cond_fuser += condition_list[idx] + alpha
1004
+
1005
+ sample = sample + controlnet_cond_fuser
1006
+ # -------------------------------------------------------------------------------------------
1007
+
1008
+ # 3. down
1009
+ down_block_res_samples = (sample,)
1010
+ for downsample_block in self.down_blocks:
1011
+ if (
1012
+ hasattr(downsample_block, "has_cross_attention")
1013
+ and downsample_block.has_cross_attention
1014
+ ):
1015
+ sample, res_samples = downsample_block(
1016
+ hidden_states=sample,
1017
+ temb=emb,
1018
+ encoder_hidden_states=encoder_hidden_states,
1019
+ attention_mask=attention_mask,
1020
+ cross_attention_kwargs=cross_attention_kwargs,
1021
+ )
1022
+ else:
1023
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1024
+
1025
+ down_block_res_samples += res_samples
1026
+
1027
+ # 4. mid
1028
+ if self.mid_block is not None:
1029
+ sample = self.mid_block(
1030
+ sample,
1031
+ emb,
1032
+ encoder_hidden_states=encoder_hidden_states,
1033
+ attention_mask=attention_mask,
1034
+ cross_attention_kwargs=cross_attention_kwargs,
1035
+ )
1036
+
1037
+ # 5. Control net blocks
1038
+
1039
+ controlnet_down_block_res_samples = ()
1040
+
1041
+ for down_block_res_sample, controlnet_block in zip(
1042
+ down_block_res_samples, self.controlnet_down_blocks
1043
+ ):
1044
+ down_block_res_sample = controlnet_block(down_block_res_sample)
1045
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
1046
+ down_block_res_sample,
1047
+ )
1048
+
1049
+ down_block_res_samples = controlnet_down_block_res_samples
1050
+
1051
+ mid_block_res_sample = self.controlnet_mid_block(sample)
1052
+
1053
+ # 6. scaling
1054
+ if guess_mode and not self.config.global_pool_conditions:
1055
+ scales = torch.logspace(
1056
+ -1, 0, len(down_block_res_samples) + 1, device=sample.device
1057
+ ) # 0.1 to 1.0
1058
+ scales = scales * conditioning_scale
1059
+ down_block_res_samples = [
1060
+ sample * scale for sample, scale in zip(down_block_res_samples, scales)
1061
+ ]
1062
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
1063
+ else:
1064
+ down_block_res_samples = [
1065
+ sample * conditioning_scale for sample in down_block_res_samples
1066
+ ]
1067
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
1068
+
1069
+ if self.config.global_pool_conditions:
1070
+ down_block_res_samples = [
1071
+ torch.mean(sample, dim=(2, 3), keepdim=True)
1072
+ for sample in down_block_res_samples
1073
+ ]
1074
+ mid_block_res_sample = torch.mean(
1075
+ mid_block_res_sample, dim=(2, 3), keepdim=True
1076
+ )
1077
+
1078
+ if not return_dict:
1079
+ return (down_block_res_samples, mid_block_res_sample)
1080
+
1081
+ return ControlNetOutput(
1082
+ down_block_res_samples=down_block_res_samples,
1083
+ mid_block_res_sample=mid_block_res_sample,
1084
+ )
1085
+
1086
+
1087
+ def zero_module(module):
1088
+ for p in module.parameters():
1089
+ nn.init.zeros_(p)
1090
+ return module
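app.py loads this class directly with `from_pretrained`, the same way a stock `ControlNetModel` is loaded. A minimal sketch of that, assuming (as the app does) that the `OzzyGT/controlnet-union-promax-sdxl-1.0` checkpoint is compatible with this class and that `num_control_type` is exposed on the config via `register_to_config`:

import torch

from controlnet_union import ControlNetModel_Union

# Same checkpoint and dtype the Space uses for the union ControlNet.
controlnet_union = ControlNetModel_Union.from_pretrained(
    "OzzyGT/controlnet-union-promax-sdxl-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)

# num_control_type is registered to the config in __init__, so the number of
# condition "slots" handled by the Condition Transformer can be inspected here.
print(controlnet_union.config.num_control_type)

Note that the `forward` above always builds an 8-slot `control_type` vector with slot 3 active and places the incoming `controlnet_cond` (the lineart image passed from app.py) in that slot.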
pipeline_sdxl_recolor.py ADDED
@@ -0,0 +1,665 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import cv2
import PIL
import torch
import torch.nn.functional as F
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import (
    AutoencoderKL,
    ControlNetModel,
    ImageProjection,
    UNet2DConditionModel,
)
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
    StableDiffusionXLPipelineOutput,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)


def latents_to_rgb(latents):
    weights = ((60, -60, 25, -70), (60, -5, 15, -50), (60, 10, -5, -35))

    weights_tensor = torch.t(
        torch.tensor(weights, dtype=latents.dtype).to(latents.device)
    )
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(
        latents.device
    )
    rgb_tensor = torch.einsum(
        "...lxy,lr -> ...rxy", latents, weights_tensor
    ) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions

    denoised_image = cv2.fastNlMeansDenoisingColored(image_array, None, 10, 10, 7, 21)
    blurred_image = cv2.GaussianBlur(denoised_image, (5, 5), 0)
    final_image = PIL.Image.fromarray(blurred_image)

    width, height = final_image.size
    final_image = final_image.resize(
        (width * 8, height * 8), PIL.Image.Resampling.LANCZOS
    )

    return final_image


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,
):
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps

    return timesteps, num_inference_steps


class StableDiffusionXLRecolorPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    StableDiffusionXLLoraLoaderMixin,
    IPAdapterMixin,
    FromSingleFileMixin,
):
    # leave controlnet out on purpose because it iterates with unet
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
        "feature_extractor",
        "image_encoder",
    ]
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[
            ControlNetModel,
            List[ControlNetModel],
            Tuple[ControlNetModel],
            MultiControlNetModel,
        ],
        scheduler: KarrasDiffusionSchedulers,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
        feature_extractor: CLIPImageProcessor = None,
        image_encoder: CLIPVisionModelWithProjection = None,
    ):
        super().__init__()

        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )
        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
        )

    def encode_prompt(
        self,
170
+ prompt: str,
171
+ negative_prompt: Optional[str] = None,
172
+ device: Optional[torch.device] = None,
173
+ do_classifier_free_guidance: bool = True,
174
+ ):
175
+ device = device or self._execution_device
176
+ prompt = [prompt] if isinstance(prompt, str) else prompt
177
+
178
+ if prompt is not None:
179
+ batch_size = len(prompt)
180
+
181
+ # Define tokenizers and text encoders
182
+ tokenizers = (
183
+ [self.tokenizer, self.tokenizer_2]
184
+ if self.tokenizer is not None
185
+ else [self.tokenizer_2]
186
+ )
187
+ text_encoders = (
188
+ [self.text_encoder, self.text_encoder_2]
189
+ if self.text_encoder is not None
190
+ else [self.text_encoder_2]
191
+ )
192
+
193
+ prompt_2 = prompt
194
+
195
+ # encode the prompt with both tokenizer/text encoder pairs
196
+ prompt_embeds_list = []
197
+ prompts = [prompt, prompt_2]
198
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
199
+ text_inputs = tokenizer(
200
+ prompt,
201
+ padding="max_length",
202
+ max_length=tokenizer.model_max_length,
203
+ truncation=True,
204
+ return_tensors="pt",
205
+ )
206
+
207
+ text_input_ids = text_inputs.input_ids
208
+
209
+ prompt_embeds = text_encoder(
210
+ text_input_ids.to(device), output_hidden_states=True
211
+ )
212
+
213
+ # We are only ever interested in the pooled output of the final text encoder
214
+ pooled_prompt_embeds = prompt_embeds[0]
215
+ prompt_embeds = prompt_embeds.hidden_states[-2]
216
+ prompt_embeds_list.append(prompt_embeds)
217
+
218
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
219
+
220
+ # get unconditional embeddings for classifier free guidance
221
+ negative_prompt_embeds = None
222
+ negative_pooled_prompt_embeds = None
223
+
224
+ if do_classifier_free_guidance:
225
+ negative_prompt = negative_prompt or ""
226
+
227
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
228
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
229
+
230
+ # normalize str to list
231
+ negative_prompt = [negative_prompt]
232
+ negative_prompt_2 = negative_prompt
233
+
234
+ uncond_tokens: List[str]
235
+ uncond_tokens = [negative_prompt, negative_prompt_2]
236
+
237
+ negative_prompt_embeds_list = []
238
+ for negative_prompt, tokenizer, text_encoder in zip(
239
+ uncond_tokens, tokenizers, text_encoders
240
+ ):
241
+ max_length = prompt_embeds.shape[1]
242
+ uncond_input = tokenizer(
243
+ negative_prompt,
244
+ padding="max_length",
245
+ max_length=max_length,
246
+ truncation=True,
247
+ return_tensors="pt",
248
+ )
249
+
250
+ negative_prompt_embeds = text_encoder(
251
+ uncond_input.input_ids.to(device),
252
+ output_hidden_states=True,
253
+ )
254
+ # We are only ever interested in the pooled output of the final text encoder
255
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
256
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
257
+
258
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
259
+
260
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
261
+
262
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
263
+
264
+ bs_embed, seq_len, _ = prompt_embeds.shape
265
+ # reshape text embeddings to (batch_size, sequence_length, -1)
266
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
267
+
268
+ if do_classifier_free_guidance:
269
+ # reshape unconditional embeddings to (batch_size, sequence_length, -1)
270
+ seq_len = negative_prompt_embeds.shape[1]
271
+
272
+ negative_prompt_embeds = negative_prompt_embeds.to(
273
+ dtype=self.text_encoder_2.dtype, device=device
274
+ )
275
+
276
+ negative_prompt_embeds = negative_prompt_embeds.view(
277
+ batch_size, seq_len, -1
278
+ )
279
+
280
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
281
+
282
+ if do_classifier_free_guidance:
283
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(
284
+ bs_embed, -1
285
+ )
286
+
287
+ return (
288
+ prompt_embeds,
289
+ negative_prompt_embeds,
290
+ pooled_prompt_embeds,
291
+ negative_pooled_prompt_embeds,
292
+ )
293
+
294
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
295
+ def encode_image(
296
+ self, image, device, num_images_per_prompt, output_hidden_states=None
297
+ ):
298
+ dtype = next(self.image_encoder.parameters()).dtype
299
+
300
+ if not isinstance(image, torch.Tensor):
301
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
302
+
303
+ image = image.to(device=device, dtype=dtype)
304
+ if output_hidden_states:
305
+ image_enc_hidden_states = self.image_encoder(
306
+ image, output_hidden_states=True
307
+ ).hidden_states[-2]
308
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
309
+ num_images_per_prompt, dim=0
310
+ )
311
+ uncond_image_enc_hidden_states = self.image_encoder(
312
+ torch.zeros_like(image), output_hidden_states=True
313
+ ).hidden_states[-2]
314
+ uncond_image_enc_hidden_states = (
315
+ uncond_image_enc_hidden_states.repeat_interleave(
316
+ num_images_per_prompt, dim=0
317
+ )
318
+ )
319
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
320
+ else:
321
+ image_embeds = self.image_encoder(image).image_embeds
322
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
323
+ uncond_image_embeds = torch.zeros_like(image_embeds)
324
+
325
+ return image_embeds, uncond_image_embeds
326
+
327
+ def prepare_ip_adapter_image_embeds(
328
+ self,
329
+ ip_adapter_image,
330
+ device,
331
+ do_classifier_free_guidance,
332
+ ):
333
+ image_embeds = []
334
+ if do_classifier_free_guidance:
335
+ negative_image_embeds = []
336
+
337
+ if not isinstance(ip_adapter_image, list):
338
+ ip_adapter_image = [ip_adapter_image]
339
+
340
+ if len(ip_adapter_image) != len(
341
+ self.unet.encoder_hid_proj.image_projection_layers
342
+ ):
343
+ raise ValueError(
344
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
345
+ )
346
+
347
+ for single_ip_adapter_image, image_proj_layer in zip(
348
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
349
+ ):
350
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
351
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
352
+ single_ip_adapter_image, device, 1, output_hidden_state
353
+ )
354
+
355
+ image_embeds.append(single_image_embeds[None, :])
356
+ if do_classifier_free_guidance:
357
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
358
+
359
+ ip_adapter_image_embeds = []
360
+
361
+ for i, single_image_embeds in enumerate(image_embeds):
362
+ if do_classifier_free_guidance:
363
+ single_image_embeds = torch.cat(
364
+ [negative_image_embeds[i], single_image_embeds], dim=0
365
+ )
366
+
367
+ single_image_embeds = single_image_embeds.to(device=device)
368
+ ip_adapter_image_embeds.append(single_image_embeds)
369
+
370
+ return ip_adapter_image_embeds
371
+
372
+ def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False):
373
+ image = self.control_image_processor.preprocess(image).to(dtype=torch.float32)
374
+
375
+ image_batch_size = image.shape[0]
376
+
377
+ image = image.repeat_interleave(image_batch_size, dim=0)
378
+ image = image.to(device=device, dtype=dtype)
379
+
380
+ if do_classifier_free_guidance:
381
+ image = torch.cat([image] * 2)
382
+
383
+ return image
384
+
385
+ def prepare_latents(
386
+ self, batch_size, num_channels_latents, height, width, dtype, device
387
+ ):
388
+ shape = (
389
+ batch_size,
390
+ num_channels_latents,
391
+ int(height) // self.vae_scale_factor,
392
+ int(width) // self.vae_scale_factor,
393
+ )
394
+
395
+ latents = randn_tensor(shape, device=device, dtype=dtype)
396
+
397
+ # scale the initial noise by the standard deviation required by the scheduler
398
+ latents = latents * self.scheduler.init_noise_sigma
399
+ return latents
400
+
401
+ @property
402
+ def guidance_scale(self):
403
+ return self._guidance_scale
404
+
405
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
406
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
407
+ # corresponds to doing no classifier free guidance.
408
+ @property
409
+ def do_classifier_free_guidance(self):
410
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
411
+
412
+ @property
413
+ def denoising_end(self):
414
+ return self._denoising_end
415
+
416
+ @property
417
+ def num_timesteps(self):
418
+ return self._num_timesteps
419
+
420
+ @torch.no_grad()
421
+ def __call__(
422
+ self,
423
+ image: PipelineImageInput = None,
424
+ num_inference_steps: int = 8,
425
+ guidance_scale: float = 2.0,
426
+ prompt_embeds: Optional[torch.Tensor] = None,
427
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
428
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
429
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
430
+ ip_adapter_image: Optional[PipelineImageInput] = None,
431
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
432
+ control_guidance_start: Union[float, List[float]] = 0.0,
433
+ control_guidance_end: Union[float, List[float]] = 1.0,
434
+ **kwargs,
435
+ ):
436
+ controlnet = self.controlnet
437
+
438
+ # align format for control guidance
439
+ if not isinstance(control_guidance_start, list) and isinstance(
440
+ control_guidance_end, list
441
+ ):
442
+ control_guidance_start = len(control_guidance_end) * [
443
+ control_guidance_start
444
+ ]
445
+ elif not isinstance(control_guidance_end, list) and isinstance(
446
+ control_guidance_start, list
447
+ ):
448
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
449
+ elif not isinstance(control_guidance_start, list) and not isinstance(
450
+ control_guidance_end, list
451
+ ):
452
+ mult = (
453
+ len(controlnet.nets)
454
+ if isinstance(controlnet, MultiControlNetModel)
455
+ else 1
456
+ )
457
+ control_guidance_start, control_guidance_end = (
458
+ mult * [control_guidance_start],
459
+ mult * [control_guidance_end],
460
+ )
461
+
462
+ self._guidance_scale = guidance_scale
463
+
464
+ # 2. Define call parameters
465
+ batch_size = 1
466
+ device = self._execution_device
467
+
468
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(
469
+ controlnet_conditioning_scale, float
470
+ ):
471
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
472
+ controlnet.nets
473
+ )
474
+
475
+ # 3.2 Encode ip_adapter_image
476
+ if ip_adapter_image is not None:
477
+ image_embeds = self.prepare_ip_adapter_image_embeds(
478
+ ip_adapter_image,
479
+ device,
480
+ self.do_classifier_free_guidance,
481
+ )
482
+
483
+ # 4. Prepare image
484
+ if isinstance(controlnet, ControlNetModel):
485
+ image = self.prepare_image(
486
+ image=image,
487
+ device=device,
488
+ dtype=controlnet.dtype,
489
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
490
+ )
491
+ height, width = image.shape[-2:]
492
+ elif isinstance(controlnet, MultiControlNetModel):
493
+ images = []
494
+
495
+ for image_ in image:
496
+ image_ = self.prepare_image(
497
+ image=image_,
498
+ device=device,
499
+ dtype=controlnet.dtype,
500
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
501
+ )
502
+
503
+ images.append(image_)
504
+
505
+ image = images
506
+ height, width = image[0].shape[-2:]
507
+ else:
508
+ raise ValueError("controlnet must be a ControlNetModel or a MultiControlNetModel")
509
+
510
+ # 5. Prepare timesteps
511
+ timesteps, num_inference_steps = retrieve_timesteps(
512
+ self.scheduler, num_inference_steps, device
513
+ )
514
+ self._num_timesteps = len(timesteps)
515
+
516
+ # 6. Prepare latent variables
517
+ num_channels_latents = self.unet.config.in_channels
518
+ latents = self.prepare_latents(
519
+ batch_size,
520
+ num_channels_latents,
521
+ height,
522
+ width,
523
+ prompt_embeds.dtype,
524
+ device,
525
+ )
526
+
527
+ # 7.1 Create tensor stating which controlnets to keep
528
+ controlnet_keep = []
529
+ for i in range(len(timesteps)):
530
+ keeps = [
531
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
532
+ for s, e in zip(control_guidance_start, control_guidance_end)
533
+ ]
534
+ controlnet_keep.append(
535
+ keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
536
+ )
537
+
538
+ # 7.2 Prepare added time ids & embeddings
539
+ add_text_embeds = pooled_prompt_embeds
540
+
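+ # SDXL micro-conditioning: (original_size, crops_coords_top_left, target_size)
+ # packed as (H, W, 0, 0, H, W) from the control image resolution.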
541
+ add_time_ids = negative_add_time_ids = torch.tensor(
542
+ image[0].shape[-2:] + torch.Size([0, 0]) + image[0].shape[-2:]
543
+ ).unsqueeze(0)
544
+
545
+ negative_add_time_ids = add_time_ids
546
+
547
+ if self.do_classifier_free_guidance:
548
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
549
+ add_text_embeds = torch.cat(
550
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
551
+ )
552
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
553
+
554
+ prompt_embeds = prompt_embeds.to(device)
555
+ add_text_embeds = add_text_embeds.to(device)
556
+ add_time_ids = add_time_ids.to(device)
557
+
558
+ added_cond_kwargs = {
559
+ "text_embeds": add_text_embeds,
560
+ "time_ids": add_time_ids,
561
+ }
562
+
563
+ # 8. Denoising loop
564
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
565
+
566
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
567
+ for i, t in enumerate(timesteps):
568
+ # expand the latents if we are doing classifier free guidance
569
+ latent_model_input = (
570
+ torch.cat([latents] * 2)
571
+ if self.do_classifier_free_guidance
572
+ else latents
573
+ )
574
+ latent_model_input = self.scheduler.scale_model_input(
575
+ latent_model_input, t
576
+ )
577
+
578
+ # controlnet(s) inference
579
+ control_model_input = latent_model_input
580
+ controlnet_prompt_embeds = prompt_embeds
581
+ controlnet_added_cond_kwargs = added_cond_kwargs
582
+
583
+ if isinstance(controlnet_keep[i], list):
584
+ cond_scale = [
585
+ c * s
586
+ for c, s in zip(
587
+ controlnet_conditioning_scale, controlnet_keep[i]
588
+ )
589
+ ]
590
+ else:
591
+ controlnet_cond_scale = controlnet_conditioning_scale
592
+ if isinstance(controlnet_cond_scale, list):
593
+ controlnet_cond_scale = controlnet_cond_scale[0]
594
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
595
+
596
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
597
+ control_model_input,
598
+ t,
599
+ encoder_hidden_states=controlnet_prompt_embeds,
600
+ controlnet_cond=image,
601
+ conditioning_scale=cond_scale,
602
+ guess_mode=False,
603
+ added_cond_kwargs=controlnet_added_cond_kwargs,
604
+ return_dict=False,
605
+ )
606
+
607
+ if ip_adapter_image is not None:
608
+ added_cond_kwargs["image_embeds"] = image_embeds
609
+
610
+ # predict the noise residual
611
+ noise_pred = self.unet(
612
+ latent_model_input,
613
+ t,
614
+ encoder_hidden_states=prompt_embeds,
615
+ timestep_cond=None,
616
+ cross_attention_kwargs={},
617
+ down_block_additional_residuals=down_block_res_samples,
618
+ mid_block_additional_residual=mid_block_res_sample,
619
+ added_cond_kwargs=added_cond_kwargs,
620
+ return_dict=False,
621
+ )[0]
622
+
623
+ # perform guidance
624
+ if self.do_classifier_free_guidance:
625
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
626
+ noise_pred = noise_pred_uncond + guidance_scale * (
627
+ noise_pred_text - noise_pred_uncond
628
+ )
629
+
630
+ # compute the previous noisy sample x_t -> x_t-1
631
+ latents = self.scheduler.step(
632
+ noise_pred, t, latents, return_dict=False
633
+ )[0]
634
+
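+ # From the fourth step on, classifier-free guidance is dropped: only the
+ # conditional half of the prompt/image batches is kept and guidance is set to 0,
+ # halving the UNet/ControlNet work for the remaining lightning steps.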
635
+ if i == 2:
636
+ prompt_embeds = prompt_embeds[-1:]
637
+ add_text_embeds = add_text_embeds[-1:]
638
+ add_time_ids = add_time_ids[-1:]
639
+
640
+ added_cond_kwargs = {
641
+ "text_embeds": add_text_embeds,
642
+ "time_ids": add_time_ids,
643
+ }
644
+
645
+ controlnet_prompt_embeds = prompt_embeds
646
+ controlnet_added_cond_kwargs = added_cond_kwargs
647
+
648
+ image = [single_image[-1:] for single_image in image]
649
+ self._guidance_scale = 0.0
650
+
651
+ # update the progress bar and yield an intermediate preview of the current latents
652
+ if i == len(timesteps) - 1 or (
653
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
654
+ ):
655
+ progress_bar.update()
656
+ yield latents_to_rgb(latents)
657
+
658
+ latents = latents / self.vae.config.scaling_factor
659
+ image = self.vae.decode(latents, return_dict=False)[0]
660
+ image = self.image_processor.postprocess(image)[0]
661
+
662
+ # Offload all models
663
+ self.maybe_free_model_hooks()
664
+
665
+ yield image
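Note that `__call__` is written as a generator rather than returning a `StableDiffusionXLPipelineOutput`: the denoising loop yields a cheap `latents_to_rgb` preview as it progresses and the VAE-decoded result is yielded last, so callers iterate over the pipeline and keep the most recent value. A minimal, self-contained sketch of that contract (the `fake_pipeline` below is a toy stand-in, not the real pipeline):

def fake_pipeline(num_inference_steps=8):
    # intermediate yields are step previews, the last yield is the final image
    for step in range(num_inference_steps):
        yield f"preview after step {step}"  # latents_to_rgb(latents) in the real pipeline
    yield "final decoded image"             # the VAE-decoded output in the real pipeline

final = None
for output in fake_pipeline():
    final = output  # a UI can display each preview here as it arrives
print(final)  # -> "final decoded image"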
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch
2
+ spaces
3
+ gradio==4.42.0
4
+ gradio-imageslider
5
+ numpy==1.26.4
6
+ transformers
7
+ accelerate
8
+ diffusers
9
+ fastapi<0.113.0
10
+ git+https://github.com/asomoza/image_gen_aux.git