Stable-X committed on
Commit
e39296c
1 Parent(s): 66c7cc9

Upload dino_controlnetvae.py

Files changed (1)
  1. dino_controlnet/dino_controlnetvae.py +403 -0
dino_controlnet/dino_controlnetvae.py ADDED
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from diffusers.models.embeddings import (
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.controlnet import ControlNetOutput
from diffusers.models import ControlNetModel

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

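# Note: zero-initialising the final conv keeps a newly added conditioning branch
# silent at the start of training, so the pretrained weights' behaviour is preserved
# until the branch learns something useful. A minimal sanity-check sketch of the
# helpers above (the sizes are arbitrary, illustrative values):
#
#     conv = zero_module(conv_nd(2, 64, 320, kernel_size=3, padding=1))
#     x = torch.randn(1, 64, 32, 32)
#     assert conv(x).abs().max().item() == 0.0  # exactly zero before any training
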

class DINOControlNetConditioningEmbedding(nn.Module):
    """
    Maps dense feature maps (e.g. DINO patch features) into the ControlNet conditioning
    embedding space. The output convolution is zero-initialised, and with
    `up_sampling='transpose'` a transposed convolution doubles the spatial resolution.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 64, 128),
        up_sampling: str = "transpose",
    ):
        super().__init__()

        self.conv_in = conv_nd(
            2, conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(
                conv_nd(2, channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.blocks.append(
                conv_nd(2, channel_in, channel_out, kernel_size=3, padding=1, stride=1)
            )

        if up_sampling == "transpose":
            # 4x4 transposed conv with stride 2 doubles the spatial resolution.
            self.conv_out = zero_module(
                nn.ConvTranspose2d(
                    in_channels=block_out_channels[-1],
                    out_channels=conditioning_embedding_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                )
            )
        else:
            # Keep the spatial resolution unchanged.
            self.conv_out = zero_module(
                conv_nd(2, block_out_channels[-1], conditioning_embedding_channels, 3, padding=1)
            )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding

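# A hedged usage sketch for the embedding above. The shapes are illustrative
# assumptions (e.g. a 37x37 grid of 1024-channel DINO patch features), not
# something mandated by this file; the channel values match the defaults of
# `DINOControlNetVAEModel` below.
#
#     dino_embed = DINOControlNetConditioningEmbedding(
#         conditioning_embedding_channels=320,
#         conditioning_channels=1024,
#         block_out_channels=(512, 128, 256, 256),
#         up_sampling="transpose",
#     )
#     feats = torch.randn(1, 1024, 37, 37)   # [B, C, H, W] feature map
#     out = dino_embed(feats)                # -> [1, 320, 74, 74], all zeros before training
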

class DINOControlNetVAEModel(ControlNetModel):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
        dino_up_sampling: str = "transpose",
        dino_conditioning_embedding_channels: int = 320,
        dino_conditioning_channels: int = 1024,
        dino_block_out_channels: Tuple[int, ...] = (512, 128, 256, 256),
    ):
        # The arguments below are passed positionally and must stay in the same order
        # as `ControlNetModel.__init__`.
        super().__init__(
            in_channels,
            conditioning_channels,
            flip_sin_to_cos,
            freq_shift,
            down_block_types,
            mid_block_type,
            only_cross_attention,
            block_out_channels,
            layers_per_block,
            downsample_padding,
            mid_block_scale_factor,
            act_fn,
            norm_num_groups,
            norm_eps,
            cross_attention_dim,
            transformer_layers_per_block,
            encoder_hid_dim,
            encoder_hid_dim_type,
            attention_head_dim,
            num_attention_heads,
            use_linear_projection,
            class_embed_type,
            addition_embed_type,
            addition_time_embed_dim,
            num_class_embeds,
            upcast_attention,
            resnet_time_scale_shift,
            projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels,
            global_pool_conditions,
            addition_embed_type_num_heads,
        )

        # DINO ControlNet conditioning embedding: projects dense DINO features into the
        # first down block's channel space (its output conv is zero-initialised).
        self.dino_controlnet_cond_embedding = DINOControlNetConditioningEmbedding(
            up_sampling=dino_up_sampling,
            conditioning_embedding_channels=dino_conditioning_embedding_channels,
            conditioning_channels=dino_conditioning_channels,
            block_out_channels=dino_block_out_channels,
        )

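    # Note on `forward` below: compared with `ControlNetModel.forward`, this variant
    #   * expects `sample` to already be in the first down block's channel space
    #     (the `conv_in` projection is intentionally skipped),
    #   * passes the first residual through without a zero conv,
    #   * runs no mid block, so `mid_block_res_sample` is always `None`,
    #   * does not consume `controlnet_cond` beyond the channel-order check.
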
    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: Optional[torch.Tensor] = None,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
        """
        The [`DINOControlNetVAEModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor.
            timestep (`Union[torch.Tensor, float, int]`):
                The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            controlnet_cond (`torch.Tensor`, *optional*, defaults to `None`):
                The conditional input tensor. Apart from the channel-order check it is not used by this
                subclass.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
                timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
                embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            added_cond_kwargs (`dict`):
                Additional conditions for the Stable Diffusion XL UNet.
            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that, if specified, is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
                In this mode, the ControlNet encoder tries its best to recognize the input content even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned; otherwise a tuple
                is returned whose first element is the tuple of down block residual samples and whose second element
                is the mid block residual sample (always `None` in this subclass).
        """
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `time_proj` does not contain any weights and always returns f32 tensors,
        # but the time embedding may be running in fp16, so cast here.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.config.addition_embed_type is not None:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb if aug_emb is not None else emb

        # 2. pre-process: `conv_in` is intentionally skipped, so `sample` is expected to
        # already be in the first down block's channel space.
        # sample = self.conv_in(sample)  # without input_blocks[0]

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. no mid block is run in this variant

        # 5. ControlNet blocks
        # the first residual (dino features) is passed through without a zero conv
        controlnet_down_block_res_samples = (down_block_res_samples[0],)

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples[1:], self.controlnet_down_blocks[1:]):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = None

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            if mid_block_res_sample is not None:
                mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
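
A minimal, hedged usage sketch of the uploaded module: it instantiates the model with its default config and shows one plausible way to apply the DINO conditioning embedding. The import path, feature shapes, and the additive injection into `sample` are illustrative assumptions, not something specified by the file itself.

    import torch
    from dino_controlnet.dino_controlnetvae import DINOControlNetVAEModel

    controlnet = DINOControlNetVAEModel()  # default config mirrors ControlNetModel's

    # DINO-style patch features, e.g. [batch, 1024, 37, 37] (assumed shape).
    dino_feats = torch.randn(1, 1024, 37, 37)
    dino_emb = controlnet.dino_controlnet_cond_embedding(dino_feats)  # -> [1, 320, 74, 74]

    # `forward` skips `conv_in`, so `sample` must already have block_out_channels[0] channels.
    sample = torch.randn(1, 320, 74, 74)
    encoder_hidden_states = torch.randn(1, 77, 1280)  # cross_attention_dim defaults to 1280

    down_res, mid_res = controlnet(
        sample + dino_emb,  # additive injection of the DINO embedding (assumed usage)
        timestep=0,
        encoder_hidden_states=encoder_hidden_states,
        return_dict=False,
    )
    # mid_res is always None in this subclass; down_res holds the per-resolution residuals.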