SunderAli17 committed on
Commit 00fb38e
1 Parent(s): 2560b63

Create min_sdxl.py

Files changed (1)
  1. module/min_sdxl.py +907 -0
module/min_sdxl.py ADDED
@@ -0,0 +1,907 @@
# Modified from minSDXL by Simo Ryu:
# https://github.com/cloneofsimo/minSDXL,
# which is in turn modified from the original code of:
# https://github.com/huggingface/diffusers
# and is therefore released under the Apache 2.0 license.

from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import inspect

from collections import namedtuple
from importlib import import_module

from torch.fft import fftn, fftshift, ifftn, ifftshift

from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

try:  # deprecated LoRA processors, only needed by `Attention.get_processor(return_deprecated_lora=True)`
    from diffusers.models.attention_processor import (
        LoRAAttnAddedKVProcessor,
        LoRAAttnProcessor,
        LoRAAttnProcessor2_0,
        LoRAXFormersAttnProcessor,
    )
except ImportError:  # newer diffusers releases may no longer ship them
    LoRAAttnAddedKVProcessor = LoRAAttnProcessor = LoRAAttnProcessor2_0 = LoRAXFormersAttnProcessor = None

# Implementation of FreeU for minSDXL

def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
    """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
    This version of the method comes from here:
    https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
    """
    x = x_in
    B, C, H, W = x.shape

    # Non-power of 2 images must be float32
    if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
        x = x.to(dtype=torch.float32)

    # FFT
    x_freq = fftn(x, dim=(-2, -1))
    x_freq = fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = ifftshift(x_freq, dim=(-2, -1))
    x_filtered = ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.to(dtype=x_in.dtype)
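
# Usage sketch for `fourier_filter` (illustrative sizes; any 4D feature map works):
#
#     feats = torch.randn(1, 320, 64, 64)
#     # Scale only the lowest frequencies (a small window around the spectrum center),
#     # which is how FreeU attenuates skip-connection features below.
#     filtered = fourier_filter(feats, threshold=1, scale=0.9)
#     assert filtered.shape == feats.shape and filtered.dtype == feats.dtype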


def apply_freeu(
    resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs):
    """Applies the FreeU mechanism as introduced in https://arxiv.org/abs/2309.11497.
    Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
    Args:
        resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
        hidden_states (`torch.Tensor`): Inputs to the underlying block.
        res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
        s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
        s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
        b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
        b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
    """
    if resolution_idx == 0:
        num_half_channels = hidden_states.shape[1] // 2
        hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
    if resolution_idx == 1:
        num_half_channels = hidden_states.shape[1] // 2
        hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
        res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])

    return hidden_states, res_hidden_states
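
# Sketch of one FreeU step at the lowest-resolution decoder stage. The b1/b2/s1/s2
# strengths below follow the values commonly suggested for SDXL in the FreeU
# repository; treat them as illustrative defaults rather than fixed constants:
#
#     backbone = torch.randn(1, 1280, 32, 32)   # features entering the up block
#     skip = torch.randn(1, 1280, 32, 32)       # matching skip-connection features
#     backbone, skip = apply_freeu(0, backbone, skip, b1=1.3, b2=1.4, s1=0.9, s2=0.2)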

# Diffusers-style LoRA to keep everything in the min_sdxl.py file

class LoRALinearLayer(nn.Module):
    r"""
    A linear layer that is used with LoRA.
    Parameters:
        in_features (`int`):
            Number of input features.
        out_features (`int`):
            Number of output features.
        rank (`int`, `optional`, defaults to 4):
            The rank of the LoRA layer.
        network_alpha (`float`, `optional`, defaults to `None`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the same
            meaning as the `--network_alpha` option in the kohya-ss trainer script. See
            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        device (`torch.device`, `optional`, defaults to `None`):
            The device to use for the layer's weights.
        dtype (`torch.dtype`, `optional`, defaults to `None`):
            The dtype to use for the layer's weights.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        network_alpha: Optional[float] = None,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRACompatibleLinear(nn.Linear):
    """
    A Linear layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
        if self.lora_layer is None:
            return

        dtype, device = self.weight.data.dtype, self.weight.data.device

        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()

        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

        if safe_fusing and torch.isnan(fused_weight).any().item():
            raise ValueError(
                "This LoRA weight seems to be broken. "
                f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
                "LoRA weights will not be fused."
            )

        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self._lora_scale = lora_scale

    def _unfuse_lora(self):
        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
            return

        fused_weight = self.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()

        unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        if self.lora_layer is None:
            out = super().forward(hidden_states)
            return out
        else:
            out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
            return out
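
# Usage sketch for the LoRA pair above (sizes and scales are illustrative):
#
#     linear = LoRACompatibleLinear(640, 640, bias=False)
#     linear.set_lora_layer(LoRALinearLayer(640, 640, rank=4, network_alpha=4.0))
#     y = linear(torch.randn(1, 77, 640), scale=0.8)   # base output + 0.8 * LoRA delta
#     linear._fuse_lora(lora_scale=0.8)                # bake the delta into `weight`
#     y_fused = linear(torch.randn(1, 77, 640))        # LoRA is now applied implicitly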

class Timesteps(nn.Module):
    def __init__(self, num_channels: int = 320):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(
            half_dim, dtype=torch.float32, device=timesteps.device
        )
        exponent = exponent / (half_dim - 0.0)

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        emb = torch.cat([cos_emb, sin_emb], dim=-1)

        return emb
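
# Shape check for the sinusoidal embedding (values are illustrative):
#
#     t_emb = Timesteps(num_channels=320)(torch.tensor([999, 500]))
#     assert t_emb.shape == (2, 320)   # [cos | sin] halves concatenated on the last dim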


class TimestepEmbedding(nn.Module):
    def __init__(self, in_features, out_features):
        super(TimestepEmbedding, self).__init__()
        self.linear_1 = nn.Linear(in_features, out_features, bias=True)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(out_features, out_features, bias=True)

    def forward(self, sample):
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)

        return sample


class ResnetBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, conv_shortcut=True):
        super(ResnetBlock2D, self).__init__()
        self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-05, affine=True)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.time_emb_proj = nn.Linear(1280, out_channels, bias=True)
        self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-05, affine=True)
        self.dropout = nn.Dropout(p=0.0, inplace=False)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.nonlinearity = nn.SiLU()
        self.conv_shortcut = None
        if conv_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        temb = self.nonlinearity(temb)
        temb = self.time_emb_proj(temb)[:, :, None, None]
        hidden_states = hidden_states + temb
        hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor


class Attention(nn.Module):
    def __init__(
        self, inner_dim, cross_attention_dim=None, num_heads=None, dropout=0.0, processor=None, scale_qk=True
    ):
        super(Attention, self).__init__()
        if num_heads is None:
            self.head_dim = 64
            self.num_heads = inner_dim // self.head_dim
        else:
            self.num_heads = num_heads
            self.head_dim = inner_dim // num_heads

        self.scale = self.head_dim**-0.5
        if cross_attention_dim is None:
            cross_attention_dim = inner_dim
        # stored for the deprecated-LoRA branch of `get_processor`
        self.inner_dim = inner_dim
        self.cross_attention_dim = cross_attention_dim
        self.to_q = LoRACompatibleLinear(inner_dim, inner_dim, bias=False)
        self.to_k = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False)
        self.to_v = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False)

        self.to_out = nn.ModuleList(
            [LoRACompatibleLinear(inner_dim, inner_dim), nn.Dropout(dropout, inplace=False)]
        )

        self.scale_qk = scale_qk
        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.
        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.
        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions;
        # here we simply pass along all tensors to the selected processor class.
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty.

        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            print(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def orig_forward(self, hidden_states, encoder_hidden_states=None):
        q = self.to_q(hidden_states)
        k = (
            self.to_k(encoder_hidden_states)
            if encoder_hidden_states is not None
            else self.to_k(hidden_states)
        )
        v = (
            self.to_v(encoder_hidden_states)
            if encoder_hidden_states is not None
            else self.to_v(hidden_states)
        )
        b, t, c = q.size()

        q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2)

        # scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # attn_weights = torch.softmax(scores, dim=-1)
        # attn_output = torch.matmul(attn_weights, v)

        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale,
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c)

        for layer in self.to_out:
            attn_output = layer(attn_output)

        return attn_output

    def set_processor(self, processor) -> None:
        r"""
        Set the attention processor to use.
        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            print(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False):
        r"""
        Get the attention processor in use.
        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.
        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
        # serialization format for LoRA Attention Processors. It should be deleted once the integration
        # with PEFT is completed.
        is_lora_activated = {
            name: module.lora_layer is not None
            for name, module in self.named_modules()
            if hasattr(module, "lora_layer")
        }

        # 1. if no layer has a LoRA activated we can return the processor as usual
        if not any(is_lora_activated.values()):
            return self.processor

        # `add_k_proj` and `add_v_proj` do not have to have LoRA applied, so drop them from the check below
        is_lora_activated.pop("add_k_proj", None)
        is_lora_activated.pop("add_v_proj", None)
        # 2. else it is not possible that only some layers have LoRA activated
        if not all(is_lora_activated.values()):
            raise ValueError(
                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
            )

        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
        non_lora_processor_cls_name = self.processor.__class__.__name__
        lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)

        hidden_size = self.inner_dim

        # now create a LoRA attention processor from the LoRA layers
        if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
            kwargs = {
                "cross_attention_dim": self.cross_attention_dim,
                "rank": self.to_q.lora_layer.rank,
                "network_alpha": self.to_q.lora_layer.network_alpha,
                "q_rank": self.to_q.lora_layer.rank,
                "q_hidden_size": self.to_q.lora_layer.out_features,
                "k_rank": self.to_k.lora_layer.rank,
                "k_hidden_size": self.to_k.lora_layer.out_features,
                "v_rank": self.to_v.lora_layer.rank,
                "v_hidden_size": self.to_v.lora_layer.out_features,
                "out_rank": self.to_out[0].lora_layer.rank,
                "out_hidden_size": self.to_out[0].lora_layer.out_features,
            }

            if hasattr(self.processor, "attention_op"):
                kwargs["attention_op"] = self.processor.attention_op

            lora_processor = lora_processor_cls(hidden_size, **kwargs)
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
            lora_processor = lora_processor_cls(
                hidden_size,
                cross_attention_dim=self.add_k_proj.weight.shape[0],
                rank=self.to_q.lora_layer.rank,
                network_alpha=self.to_q.lora_layer.network_alpha,
            )
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())

            # only save if used
            if self.add_k_proj.lora_layer is not None:
                lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
                lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
            else:
                lora_processor.add_k_proj_lora = None
                lora_processor.add_v_proj_lora = None
        else:
            raise ValueError(f"{lora_processor_cls} does not exist.")

        return lora_processor
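
# The stock diffusers processors generally expect the richer attribute set of the full
# diffusers `Attention` class; this module also keeps a self-contained path in
# `orig_forward`. A minimal sketch for routing every attention call through it (assumes
# unmasked attention and a hypothetical `unet` instance built from the blocks below):
#
#     class OrigForwardProcessor:
#         def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
#             return attn.orig_forward(hidden_states, encoder_hidden_states)
#
#     for module in unet.modules():
#         if isinstance(module, Attention):
#             module.set_processor(OrigForwardProcessor())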

class GEGLU(nn.Module):
    def __init__(self, in_features, out_features):
        super(GEGLU, self).__init__()
        self.proj = nn.Linear(in_features, out_features * 2, bias=True)

    def forward(self, x):
        x_proj = self.proj(x)
        x1, x2 = x_proj.chunk(2, dim=-1)
        return x1 * torch.nn.functional.gelu(x2)


class FeedForward(nn.Module):
    def __init__(self, in_features, out_features):
        super(FeedForward, self).__init__()

        self.net = nn.ModuleList(
            [
                GEGLU(in_features, out_features * 4),
                nn.Dropout(p=0.0, inplace=False),
                nn.Linear(out_features * 4, out_features, bias=True),
            ]
        )

    def forward(self, x):
        for layer in self.net:
            x = layer(x)
        return x


class BasicTransformerBlock(nn.Module):
    def __init__(self, hidden_size):
        super(BasicTransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
        self.attn1 = Attention(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
        self.attn2 = Attention(hidden_size, 2048)
        self.norm3 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
        self.ff = FeedForward(hidden_size, hidden_size)

    def forward(self, x, encoder_hidden_states=None):
        residual = x

        x = self.norm1(x)
        x = self.attn1(x)
        x = x + residual

        residual = x

        x = self.norm2(x)
        if encoder_hidden_states is not None:
            x = self.attn2(x, encoder_hidden_states)
        else:
            x = self.attn2(x)
        x = x + residual

        residual = x

        x = self.norm3(x)
        x = self.ff(x)
        x = x + residual
        return x


class Transformer2DModel(nn.Module):
    def __init__(self, in_channels, out_channels, n_layers):
        super(Transformer2DModel, self).__init__()
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True)
        self.proj_in = nn.Linear(in_channels, out_channels, bias=True)
        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(out_channels) for _ in range(n_layers)]
        )
        self.proj_out = nn.Linear(out_channels, out_channels, bias=True)

    def forward(self, hidden_states, encoder_hidden_states=None):
        batch, _, height, width = hidden_states.shape
        res = hidden_states
        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
            batch, height * width, inner_dim
        )
        hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states)

        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states.reshape(batch, height, width, inner_dim)
            .permute(0, 3, 1, 2)
            .contiguous()
        )

        return hidden_states + res


class Downsample2D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample2D, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1
        )

    def forward(self, x):
        return self.conv(x)


class Upsample2D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample2D, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)


class DownBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownBlock2D, self).__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(in_channels, out_channels, conv_shortcut=False),
                ResnetBlock2D(out_channels, out_channels, conv_shortcut=False),
            ]
        )
        self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])

    def forward(self, hidden_states, temb):
        output_states = []
        for module in self.resnets:
            hidden_states = module(hidden_states, temb)
            output_states.append(hidden_states)

        hidden_states = self.downsamplers[0](hidden_states)
        output_states.append(hidden_states)

        return hidden_states, output_states


class CrossAttnDownBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, n_layers, has_downsamplers=True):
        super(CrossAttnDownBlock2D, self).__init__()
        self.attentions = nn.ModuleList(
            [
                Transformer2DModel(out_channels, out_channels, n_layers),
                Transformer2DModel(out_channels, out_channels, n_layers),
            ]
        )
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(in_channels, out_channels),
                ResnetBlock2D(out_channels, out_channels, conv_shortcut=False),
            ]
        )
        self.downsamplers = None
        if has_downsamplers:
            self.downsamplers = nn.ModuleList(
                [Downsample2D(out_channels, out_channels)]
            )

    def forward(self, hidden_states, temb, encoder_hidden_states):
        output_states = []
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )
            output_states.append(hidden_states)

        if self.downsamplers is not None:
            hidden_states = self.downsamplers[0](hidden_states)
            output_states.append(hidden_states)

        return hidden_states, output_states


class CrossAttnUpBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, prev_output_channel, n_layers):
        super(CrossAttnUpBlock2D, self).__init__()
        self.attentions = nn.ModuleList(
            [
                Transformer2DModel(out_channels, out_channels, n_layers),
                Transformer2DModel(out_channels, out_channels, n_layers),
                Transformer2DModel(out_channels, out_channels, n_layers),
            ]
        )
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(prev_output_channel + out_channels, out_channels),
                ResnetBlock2D(2 * out_channels, out_channels),
                ResnetBlock2D(out_channels + in_channels, out_channels),
            ]
        )
        self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])

    def forward(
        self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states
    ):
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class UpBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, prev_output_channel):
        super(UpBlock2D, self).__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(out_channels + prev_output_channel, out_channels),
                ResnetBlock2D(out_channels * 2, out_channels),
                ResnetBlock2D(out_channels + in_channels, out_channels),
            ]
        )

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
            and getattr(self, "resolution_idx", None)
        )

        for resnet in self.resnets:
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            if is_freeu_enabled:
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states

class UNetMidBlock2DCrossAttn(nn.Module):
    def __init__(self, in_features):
        super(UNetMidBlock2DCrossAttn, self).__init__()
        self.attentions = nn.ModuleList(
            [Transformer2DModel(in_features, in_features, n_layers=10)]
        )
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(in_features, in_features, conv_shortcut=False),
                ResnetBlock2D(in_features, in_features, conv_shortcut=False),
            ]
        )

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )
            hidden_states = resnet(hidden_states, temb)

        return hidden_states


class UNet2DConditionModel(nn.Module):
    def __init__(self):
        super(UNet2DConditionModel, self).__init__()

        # This is needed to imitate the huggingface config behavior and
        # has nothing to do with the model itself;
        # remove it if you don't use the diffusers pipeline.
        self.config = namedtuple(
            "config", "in_channels addition_time_embed_dim sample_size"
        )
        self.config.in_channels = 4
        self.config.addition_time_embed_dim = 256
        self.config.sample_size = 128

        self.conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1)
        self.time_proj = Timesteps()
        self.time_embedding = TimestepEmbedding(in_features=320, out_features=1280)
        self.add_time_proj = Timesteps(256)
        self.add_embedding = TimestepEmbedding(in_features=2816, out_features=1280)
        self.down_blocks = nn.ModuleList(
            [
                DownBlock2D(in_channels=320, out_channels=320),
                CrossAttnDownBlock2D(in_channels=320, out_channels=640, n_layers=2),
                CrossAttnDownBlock2D(
                    in_channels=640,
                    out_channels=1280,
                    n_layers=10,
                    has_downsamplers=False,
                ),
            ]
        )
        self.up_blocks = nn.ModuleList(
            [
                CrossAttnUpBlock2D(
                    in_channels=640,
                    out_channels=1280,
                    prev_output_channel=1280,
                    n_layers=10,
                ),
                CrossAttnUpBlock2D(
                    in_channels=320,
                    out_channels=640,
                    prev_output_channel=1280,
                    n_layers=2,
                ),
                UpBlock2D(in_channels=320, out_channels=320, prev_output_channel=640),
            ]
        )
        self.mid_block = UNetMidBlock2DCrossAttn(1280)
        self.conv_norm_out = nn.GroupNorm(32, 320, eps=1e-05, affine=True)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(320, 4, kernel_size=3, stride=1, padding=1)

    def forward(
        self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs
    ):
        # Implement the forward pass through the model
        timesteps = timesteps.expand(sample.shape[0])
        t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb)

        text_embeds = added_cond_kwargs.get("text_embeds")
        time_ids = added_cond_kwargs.get("time_ids")

        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb

        sample = self.conv_in(sample)

        # 3. down
        s0 = sample
        sample, [s1, s2, s3] = self.down_blocks[0](
            sample,
            temb=emb,
        )

        sample, [s4, s5, s6] = self.down_blocks[1](
            sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
        )

        sample, [s7, s8] = self.down_blocks[2](
            sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
        )

        # 4. mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states
        )

        # 5. up
        sample = self.up_blocks[0](
            hidden_states=sample,
            temb=emb,
            res_hidden_states_tuple=[s6, s7, s8],
            encoder_hidden_states=encoder_hidden_states,
        )

        sample = self.up_blocks[1](
            hidden_states=sample,
            temb=emb,
            res_hidden_states_tuple=[s3, s4, s5],
            encoder_hidden_states=encoder_hidden_states,
        )

        sample = self.up_blocks[2](
            hidden_states=sample,
            temb=emb,
            res_hidden_states_tuple=[s0, s1, s2],
        )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return [sample]
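
# Expected input shapes for a single denoising step, derived from the layer sizes above.
# Running it end to end additionally assumes an attention processor compatible with this
# trimmed-down `Attention` module (e.g. the `orig_forward` routing sketched earlier):
#
#     unet = UNet2DConditionModel()
#     sample = torch.randn(1, 4, 128, 128)              # SDXL latents for a 1024x1024 image
#     timesteps = torch.tensor([999])
#     encoder_hidden_states = torch.randn(1, 77, 2048)  # concatenated text-encoder states
#     added_cond_kwargs = {
#         "text_embeds": torch.randn(1, 1280),          # pooled text embedding
#         "time_ids": torch.tensor([[1024., 1024., 0., 0., 1024., 1024.]]),  # size/crop/target conditioning
#     }
#     noise_pred = unet(sample, timesteps, encoder_hidden_states, added_cond_kwargs)[0]
#     # noise_pred has the same shape as `sample`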