liuhuijie03 committed on
Commit
891b94e
·
1 Parent(s): d9af839
Files changed (38)
  1. .gitignore +0 -1
  2. app.py +1 -2
  3. lakonlab/models/architecture/__init__.py +1 -0
  4. lakonlab/models/architecture/diffusers/__init__.py +15 -0
  5. lakonlab/models/architecture/diffusers/__pycache__/__init__.cpython-310.pyc +0 -0
  6. lakonlab/models/architecture/diffusers/__pycache__/dit.cpython-310.pyc +0 -0
  7. lakonlab/models/architecture/diffusers/__pycache__/flux.cpython-310.pyc +0 -0
  8. lakonlab/models/architecture/diffusers/__pycache__/pretrained.cpython-310.pyc +0 -0
  9. lakonlab/models/architecture/diffusers/__pycache__/qwen.cpython-310.pyc +0 -0
  10. lakonlab/models/architecture/diffusers/__pycache__/sd3.cpython-310.pyc +0 -0
  11. lakonlab/models/architecture/diffusers/__pycache__/unet.cpython-310.pyc +0 -0
  12. lakonlab/models/architecture/diffusers/dit.py +428 -0
  13. lakonlab/models/architecture/diffusers/flux.py +156 -0
  14. lakonlab/models/architecture/diffusers/pretrained.py +281 -0
  15. lakonlab/models/architecture/diffusers/qwen.py +139 -0
  16. lakonlab/models/architecture/diffusers/sd3.py +80 -0
  17. lakonlab/models/architecture/diffusers/unet.py +192 -0
  18. models/lakonlab/models/architecture/diffusers/__init__.py +15 -0
  19. models/lakonlab/models/architecture/diffusers/__pycache__/__init__.cpython-310.pyc +0 -0
  20. models/lakonlab/models/architecture/diffusers/__pycache__/dit.cpython-310.pyc +0 -0
  21. models/lakonlab/models/architecture/diffusers/__pycache__/flux.cpython-310.pyc +0 -0
  22. models/lakonlab/models/architecture/diffusers/__pycache__/pretrained.cpython-310.pyc +0 -0
  23. models/lakonlab/models/architecture/diffusers/__pycache__/qwen.cpython-310.pyc +0 -0
  24. models/lakonlab/models/architecture/diffusers/__pycache__/sd3.cpython-310.pyc +0 -0
  25. models/lakonlab/models/architecture/diffusers/__pycache__/unet.cpython-310.pyc +0 -0
  26. models/lakonlab/models/architecture/diffusers/dit.py +428 -0
  27. models/lakonlab/models/architecture/diffusers/flux.py +156 -0
  28. models/lakonlab/models/architecture/diffusers/pretrained.py +281 -0
  29. models/lakonlab/models/architecture/diffusers/qwen.py +139 -0
  30. models/lakonlab/models/architecture/diffusers/sd3.py +80 -0
  31. models/lakonlab/models/architecture/diffusers/unet.py +192 -0
  32. piFlow/lakonlab/models/architecture/diffusers/__init__.py +15 -0
  33. piFlow/lakonlab/models/architecture/diffusers/dit.py +428 -0
  34. piFlow/lakonlab/models/architecture/diffusers/flux.py +156 -0
  35. piFlow/lakonlab/models/architecture/diffusers/pretrained.py +281 -0
  36. piFlow/lakonlab/models/architecture/diffusers/qwen.py +139 -0
  37. piFlow/lakonlab/models/architecture/diffusers/sd3.py +80 -0
  38. piFlow/lakonlab/models/architecture/diffusers/unet.py +192 -0
.gitignore CHANGED
@@ -1,3 +1,2 @@
 tmp.png
-diffusers/
 src
app.py CHANGED
@@ -480,6 +480,7 @@ with gr.Blocks(
     """
     > ❗️ <strong>Note</strong>:
     > The Gradio apps use an accelerated version, which may result in a slight reduction in image generation quality.
+    > - This demo is the open-source version, utilizing [Qwen-Image](https://github.com/QwenLM/Qwen-Image) as the pre-trained model, while the more powerful closed-source version employs Kolors 2.1 as the pre-trained model and will soon be launched on [KlingAI](https://app.klingai.com/global/?gad_source=1&gad_campaignid=22803840655&gbraid=0AAAAA_AcKMnNNjEHRRI1l9_5z1qK881dO).
     """
     )

@@ -489,8 +490,6 @@
     > - Adjust the <strong>Number of Prompts</strong> slider to add or remove input rows.
     > - Type your own prompts directly in the text boxes.
     > - You can click any template below to quickly load preset style code and prompts.
-    > - This model is the open-source version, utilizing [Qwen-Image](https://github.com/QwenLM/Qwen-Image) as the pre-trained model, while the more powerful closed-source version employs Kolors 2.1 as the pre-trained model and will soon be launched on the [KlingAI](https://app.klingai.com/global/?gad_source=1&gad_campaignid=22803840655&gbraid=0AAAAA_AcKMnNNjEHRRI1l9_5z1qK881dO).
-
     """
     )
lakonlab/models/architecture/__init__.py CHANGED
@@ -1,5 +1,6 @@
 import sys
 sys.path.insert(0, '/home/user/app')
+print('=====insert=====')
 from .ddpm import *
 from .diffusers import *
 from .gmflow import *
lakonlab/models/architecture/diffusers/__init__.py ADDED
@@ -0,0 +1,15 @@
+from .pretrained import (
+    PretrainedVAE, PretrainedVAEDecoder, PretrainedVAEEncoder, PretrainedVAEQwenImage,
+    PretrainedFluxTextEncoder, PretrainedQwenImageTextEncoder, PretrainedStableDiffusion3TextEncoder)
+from .unet import UNet2DConditionModel
+from .flux import FluxTransformer2DModel
+from .dit import DiTTransformer2DModelMod
+from .sd3 import SD3Transformer2DModel
+from .qwen import QwenImageTransformer2DModel
+
+__all__ = [
+    'PretrainedVAE', 'PretrainedVAEDecoder', 'PretrainedVAEEncoder', 'PretrainedFluxTextEncoder',
+    'PretrainedQwenImageTextEncoder', 'UNet2DConditionModel', 'FluxTransformer2DModel',
+    'DiTTransformer2DModelMod', 'SD3Transformer2DModel',
+    'QwenImageTransformer2DModel', 'PretrainedVAEQwenImage', 'PretrainedStableDiffusion3TextEncoder',
+]
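Note: every class exported here is decorated with mmgen's MODULES registry, so downstream configs instantiate them by type name rather than by direct import. A minimal, hypothetical sketch of that pattern (the checkpoint path and options are placeholders, not values from this commit, and it assumes the lakonlab package is importable):

    # Hypothetical usage sketch; 'path/to/vae' is a placeholder checkpoint location.
    from mmcv.utils import build_from_cfg
    from mmgen.models.builder import MODULES
    import lakonlab.models.architecture  # noqa: F401 -- importing runs the register_module() decorators

    vae_cfg = dict(
        type='PretrainedVAE',
        from_pretrained='path/to/vae',   # placeholder
        torch_dtype='float32',
    )
    vae = build_from_cfg(vae_cfg, MODULES)   # frozen, eval-mode wrapper by default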
lakonlab/models/architecture/diffusers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (842 Bytes)
lakonlab/models/architecture/diffusers/__pycache__/dit.cpython-310.pyc ADDED
Binary file (12.5 kB)
lakonlab/models/architecture/diffusers/__pycache__/flux.cpython-310.pyc ADDED
Binary file (4.68 kB)
lakonlab/models/architecture/diffusers/__pycache__/pretrained.cpython-310.pyc ADDED
Binary file (8.74 kB)
lakonlab/models/architecture/diffusers/__pycache__/qwen.cpython-310.pyc ADDED
Binary file (4.16 kB)
lakonlab/models/architecture/diffusers/__pycache__/sd3.cpython-310.pyc ADDED
Binary file (2.45 kB)
lakonlab/models/architecture/diffusers/__pycache__/unet.cpython-310.pyc ADDED
Binary file (5.14 kB)
lakonlab/models/architecture/diffusers/dit.py ADDED
@@ -0,0 +1,428 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Any, Dict, Optional
+from diffusers.models import DiTTransformer2DModel, ModelMixin
+from diffusers.models.attention import BasicTransformerBlock, _chunked_feed_forward, Attention, FeedForward
+from diffusers.models.embeddings import (
+    PatchEmbed, Timesteps, CombinedTimestepLabelEmbeddings, TimestepEmbedding, LabelEmbedding)
+from diffusers.models.normalization import AdaLayerNormZero
+from diffusers.configuration_utils import register_to_config
+from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict
+from mmcv.cnn import constant_init, xavier_init
+from mmgen.models.builder import MODULES
+from mmgen.utils import get_root_logger
+from ..utils import flex_freeze
+
+
+class LabelEmbeddingMod(LabelEmbedding):
+    def __init__(self, num_classes, hidden_size, dropout_prob=0.0, use_cfg_embedding=True):
+        super(LabelEmbedding, self).__init__()
+        if dropout_prob > 0:
+            assert use_cfg_embedding
+        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+        self.num_classes = num_classes
+        self.dropout_prob = dropout_prob
+
+
+class CombinedTimestepLabelEmbeddingsMod(CombinedTimestepLabelEmbeddings):
+    """
+    Modified CombinedTimestepLabelEmbeddings for reproducing the original DiT (downscale_freq_shift=0).
+    """
+    def __init__(
+            self, num_classes, embedding_dim, class_dropout_prob=0.1, downscale_freq_shift=0, use_cfg_embedding=True):
+        super(CombinedTimestepLabelEmbeddings, self).__init__()
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=downscale_freq_shift)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.class_embedder = LabelEmbeddingMod(num_classes, embedding_dim, class_dropout_prob, use_cfg_embedding)
+
+
+class BasicTransformerBlockMod(BasicTransformerBlock):
+    """
+    Modified BasicTransformerBlock for reproducing the original DiT with shared time and class
+    embeddings across all layers.
+    """
+    def __init__(
+            self,
+            dim: int,
+            num_attention_heads: int,
+            attention_head_dim: int,
+            dropout=0.0,
+            cross_attention_dim: Optional[int] = None,
+            activation_fn: str = 'geglu',
+            num_embeds_ada_norm: Optional[int] = None,
+            attention_bias: bool = False,
+            only_cross_attention: bool = False,
+            double_self_attention: bool = False,
+            upcast_attention: bool = False,
+            norm_elementwise_affine: bool = True,
+            norm_type: str = 'layer_norm',
+            norm_eps: float = 1e-5,
+            final_dropout: bool = False,
+            attention_type: str = 'default',
+            ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
+            ada_norm_bias: Optional[int] = None,
+            ff_inner_dim: Optional[int] = None,
+            ff_bias: bool = True,
+            attention_out_bias: bool = True):
+        super(BasicTransformerBlock, self).__init__()
+        self.only_cross_attention = only_cross_attention
+        self.norm_type = norm_type
+        self.num_embeds_ada_norm = num_embeds_ada_norm
+
+        assert self.norm_type == 'ada_norm_zero'
+        self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+            out_bias=attention_out_bias,
+        )
+
+        self.norm2 = None
+        self.attn2 = None
+
+        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            encoder_hidden_states: Optional[torch.Tensor] = None,
+            encoder_attention_mask: Optional[torch.Tensor] = None,
+            timestep: Optional[torch.LongTensor] = None,
+            cross_attention_kwargs: Dict[str, Any] = None,
+            class_labels: Optional[torch.LongTensor] = None,
+            emb: Optional[torch.Tensor] = None,
+            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype, emb=emb)
+
+        if cross_attention_kwargs is None:
+            cross_attention_kwargs = dict()
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs)
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+
+        hidden_states = attn_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        norm_hidden_states = self.norm3(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
+class _DiTTransformer2DModelMod(DiTTransformer2DModel):
+
+    @register_to_config
+    def __init__(
+            self,
+            class_dropout_prob=0.0,
+            num_attention_heads: int = 16,
+            attention_head_dim: int = 72,
+            in_channels: int = 4,
+            out_channels: Optional[int] = None,
+            num_layers: int = 28,
+            dropout: float = 0.0,
+            norm_num_groups: int = 32,
+            attention_bias: bool = True,
+            sample_size: int = 32,
+            patch_size: int = 2,
+            activation_fn: str = 'gelu-approximate',
+            num_embeds_ada_norm: Optional[int] = 1000,
+            upcast_attention: bool = False,
+            norm_type: str = 'ada_norm_zero',
+            norm_elementwise_affine: bool = False,
+            norm_eps: float = 1e-5):
+
+        super(DiTTransformer2DModel, self).__init__()
+
+        # Validate inputs.
+        if norm_type != "ada_norm_zero":
+            raise NotImplementedError(
+                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+            )
+        elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+            )
+
+        # Set some common variables used across the board.
+        self.attention_head_dim = attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+
+        # 2. Initialize the position embedding and transformer blocks.
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.inner_dim)
+        self.emb = CombinedTimestepLabelEmbeddingsMod(
+            num_embeds_ada_norm, self.inner_dim, class_dropout_prob=0.0)
+
+        self.transformer_blocks = nn.ModuleList([
+            BasicTransformerBlockMod(
+                self.inner_dim,
+                self.config.num_attention_heads,
+                self.config.attention_head_dim,
+                dropout=self.config.dropout,
+                activation_fn=self.config.activation_fn,
+                num_embeds_ada_norm=None,
+                attention_bias=self.config.attention_bias,
+                upcast_attention=self.config.upcast_attention,
+                norm_type=norm_type,
+                norm_elementwise_affine=self.config.norm_elementwise_affine,
+                norm_eps=self.config.norm_eps)
+            for _ in range(self.config.num_layers)])
+
+        # 3. Output blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+        self.proj_out_2 = nn.Linear(
+            self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
+
+    # https://github.com/facebookresearch/DiT/blob/main/models.py
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                xavier_init(m, distribution='uniform')
+            elif isinstance(m, nn.Embedding):
+                torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
+
+        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+        w = self.pos_embed.proj.weight.data
+        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+        nn.init.constant_(self.pos_embed.proj.bias, 0)
+
+        # Zero-out adaLN modulation layers in DiT blocks
+        for m in self.modules():
+            if isinstance(m, AdaLayerNormZero):
+                constant_init(m.linear, val=0)
+
+        # Zero-out output layers
+        constant_init(self.proj_out_1, val=0)
+        constant_init(self.proj_out_2, val=0)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            timestep: Optional[torch.LongTensor] = None,
+            class_labels: Optional[torch.LongTensor] = None,
+            cross_attention_kwargs: Dict[str, Any] = None):
+        # 1. Input
+        bs, _, h, w = hidden_states.size()
+        height, width = h // self.patch_size, w // self.patch_size
+        hidden_states = self.pos_embed(hidden_states)
+
+        cond_emb = self.emb(
+            timestep, class_labels, hidden_dtype=hidden_states.dtype)
+        dropout_enabled = self.config.class_dropout_prob > 0 and self.training
+        if dropout_enabled:
+            uncond_emb = self.emb(timestep, torch.full_like(
+                class_labels, self.config.num_embeds_ada_norm), hidden_dtype=hidden_states.dtype)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if dropout_enabled:
+                dropout_mask = torch.rand((bs, 1), device=hidden_states.device) < self.config.class_dropout_prob
+                emb = torch.where(dropout_mask, uncond_emb, cond_emb)
+            else:
+                emb = cond_emb
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    None,
+                    None,
+                    timestep,
+                    cross_attention_kwargs,
+                    class_labels,
+                    emb,
+                    use_reentrant=False)
+
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=None,
+                    encoder_hidden_states=None,
+                    encoder_attention_mask=None,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=class_labels,
+                    emb=emb)
+
+        # 3. Output
+        if dropout_enabled:
+            dropout_mask = torch.rand((bs, 1), device=hidden_states.device) < self.config.class_dropout_prob
+            emb = torch.where(dropout_mask, uncond_emb, cond_emb)
+        else:
+            emb = cond_emb
+        shift, scale = self.proj_out_1(F.silu(emb)).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+        output = self.proj_out_2(hidden_states).reshape(
+            bs, height, width, self.patch_size, self.patch_size, self.out_channels
+        ).permute(0, 5, 1, 3, 2, 4).reshape(
+            bs, self.out_channels, height * self.patch_size, width * self.patch_size)
+
+        return output
+
+
+@MODULES.register_module()
+class DiTTransformer2DModelMod(_DiTTransformer2DModelMod):
+
+    def __init__(
+            self,
+            *args,
+            freeze=False,
+            freeze_exclude=[],
+            pretrained=None,
+            torch_dtype='float32',
+            autocast_dtype=None,
+            freeze_exclude_fp32=True,
+            freeze_exclude_autocast_dtype='float32',
+            checkpointing=True,
+            **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.init_weights(pretrained)
+
+        if autocast_dtype is not None:
+            assert torch_dtype == 'float32'
+        self.autocast_dtype = autocast_dtype
+
+        if torch_dtype is not None:
+            self.to(getattr(torch, torch_dtype))
+
+        self.freeze = freeze
+        if self.freeze:
+            flex_freeze(
+                self,
+                exclude_keys=freeze_exclude,
+                exclude_fp32=freeze_exclude_fp32,
+                exclude_autocast_dtype=freeze_exclude_autocast_dtype)
+
+        if checkpointing:
+            self.enable_gradient_checkpointing()
+
+    def init_weights(self, pretrained=None):
+        super().init_weights()
+        if pretrained is not None:
+            logger = get_root_logger()
+            # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
+            checkpoint = _load_checkpoint(pretrained, map_location='cpu', logger=logger)
+            if 'state_dict' in checkpoint:
+                state_dict = checkpoint['state_dict']
+            else:
+                state_dict = checkpoint
+            # load from GMDiT V1 model with 1 Gaussian
+            p2 = self.config.patch_size * self.config.patch_size
+            ori_out_channels = p2 * self.out_channels
+            if 'proj_out_2.weight' in state_dict:
+                # if this is GMDiT V1 model with 1 Gaussian
+                if state_dict['proj_out_2.weight'].size(0) == p2 * (self.out_channels + 1):
+                    state_dict['proj_out_2.weight'] = state_dict['proj_out_2.weight'].reshape(
+                        p2, self.out_channels + 1, -1
+                    )[:, :-1].reshape(ori_out_channels, -1)
+                # if this is original DiT with variance prediction
+                if state_dict['proj_out_2.weight'].size(0) == 2 * ori_out_channels:
+                    state_dict['proj_out_2.weight'] = state_dict['proj_out_2.weight'].reshape(
+                        p2, 2 * self.out_channels, -1
+                    )[:, :self.out_channels].reshape(ori_out_channels, -1)
+            if 'proj_out_2.bias' in state_dict:
+                # if this is GMDiT V1 model with 1 Gaussian
+                if state_dict['proj_out_2.bias'].size(0) == p2 * (self.out_channels + 1):
+                    state_dict['proj_out_2.bias'] = state_dict['proj_out_2.bias'].reshape(
+                        p2, self.out_channels + 1
+                    )[:, :-1].reshape(ori_out_channels)
+                # if this is original DiT with variance prediction
+                if state_dict['proj_out_2.bias'].size(0) == 2 * ori_out_channels:
+                    state_dict['proj_out_2.bias'] = state_dict['proj_out_2.bias'].reshape(
+                        p2, 2 * self.out_channels
+                    )[:, :self.out_channels].reshape(ori_out_channels)
+            if 'emb.class_embedder.embedding_table.weight' not in state_dict \
+                    and 'transformer_blocks.0.norm1.emb.class_embedder.embedding_table.weight' in state_dict:
+                # convert original diffusers DiT model to our modified DiT model with shared embeddings
+                keys_to_remove = []
+                state_update = {}
+                for k, v in state_dict.items():
+                    if k.startswith('transformer_blocks.0.norm1.emb.'):
+                        new_k = k.replace('transformer_blocks.0.norm1.', '')
+                        state_update[new_k] = v
+                    if k.startswith('transformer_blocks.') and '.norm1.emb.' in k:
+                        keys_to_remove.append(k)
+                state_dict.update(state_update)
+                for k in keys_to_remove:
+                    del state_dict[k]
+            load_state_dict(self, state_dict, logger=logger)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            timestep: Optional[torch.LongTensor] = None,
+            class_labels: Optional[torch.LongTensor] = None,
+            **kwargs):
+        if self.autocast_dtype is not None:
+            dtype = getattr(torch, self.autocast_dtype)
+        else:
+            dtype = hidden_states.dtype
+        with torch.autocast(
+                device_type='cuda',
+                enabled=self.autocast_dtype is not None,
+                dtype=dtype if self.autocast_dtype is not None else None):
+            return super().forward(
+                hidden_states.to(dtype),
+                timestep=timestep,
+                class_labels=class_labels,
+                **kwargs)
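Note: the class-dropout logic in the forward pass above swaps the conditional embedding for the unconditional (null-class) embedding per sample, which is what enables classifier-free guidance at inference. A tiny standalone illustration of that torch.where-based masking, with made-up shapes:

    import torch

    bs, dim = 4, 8
    cond_emb = torch.ones(bs, dim)       # stands in for self.emb(timestep, class_labels)
    uncond_emb = torch.zeros(bs, dim)    # stands in for the null-class embedding
    class_dropout_prob = 0.5

    dropout_mask = torch.rand((bs, 1)) < class_dropout_prob   # one decision per sample
    emb = torch.where(dropout_mask, uncond_emb, cond_emb)     # broadcasts over the feature dim
    assert emb.shape == (bs, dim)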
lakonlab/models/architecture/diffusers/flux.py ADDED
@@ -0,0 +1,156 @@
+import torch
+
+from typing import Optional
+from accelerate import init_empty_weights
+from diffusers.models import FluxTransformer2DModel as _FluxTransformer2DModel
+from peft import LoraConfig
+from mmgen.models.builder import MODULES
+from mmgen.utils import get_root_logger
+from ..utils import flex_freeze
+from lakonlab.runner.checkpoint import load_checkpoint, _load_checkpoint
+
+
+@MODULES.register_module()
+class FluxTransformer2DModel(_FluxTransformer2DModel):
+
+    def __init__(
+            self,
+            *args,
+            patch_size=2,
+            freeze=False,
+            freeze_exclude=[],
+            pretrained=None,
+            pretrained_lora=None,
+            pretrained_lora_scale=1.0,
+            torch_dtype='float32',
+            freeze_exclude_fp32=True,
+            freeze_exclude_autocast_dtype='float32',
+            checkpointing=True,
+            use_lora=False,
+            lora_target_modules=None,
+            lora_rank=16,
+            **kwargs):
+        with init_empty_weights():
+            super().__init__(patch_size=1, *args, **kwargs)
+        self.patch_size = patch_size
+
+        self.init_weights(pretrained, pretrained_lora, pretrained_lora_scale)
+
+        self.use_lora = use_lora
+        self.lora_target_modules = lora_target_modules
+        self.lora_rank = lora_rank
+        if self.use_lora:
+            transformer_lora_config = LoraConfig(
+                r=lora_rank,
+                lora_alpha=lora_rank,
+                init_lora_weights='gaussian',
+                target_modules=lora_target_modules,
+            )
+            self.add_adapter(transformer_lora_config)
+
+        if torch_dtype is not None:
+            self.to(getattr(torch, torch_dtype))
+
+        self.freeze = freeze
+        if self.freeze:
+            flex_freeze(
+                self,
+                exclude_keys=freeze_exclude,
+                exclude_fp32=freeze_exclude_fp32,
+                exclude_autocast_dtype=freeze_exclude_autocast_dtype)
+
+        if checkpointing:
+            self.enable_gradient_checkpointing()
+
+    def init_weights(self, pretrained=None, pretrained_lora=None, pretrained_lora_scale=1.0):
+        if pretrained is not None:
+            logger = get_root_logger()
+            load_checkpoint(
+                self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
+        if pretrained_lora is not None:
+            if not isinstance(pretrained_lora, (list, tuple)):
+                assert isinstance(pretrained_lora, str)
+                pretrained_lora = [pretrained_lora]
+            if not isinstance(pretrained_lora_scale, (list, tuple)):
+                assert isinstance(pretrained_lora_scale, (int, float))
+                pretrained_lora_scale = [pretrained_lora_scale]
+            for pretrained_lora_single, pretrained_lora_scale_single in zip(pretrained_lora, pretrained_lora_scale):
+                lora_state_dict = _load_checkpoint(
+                    pretrained_lora_single, map_location='cpu', logger=logger)
+                self.load_lora_adapter(lora_state_dict)
+                self.fuse_lora(lora_scale=pretrained_lora_scale_single)
+                self.unload_lora()
+
+    @staticmethod
+    def _prepare_latent_image_ids(height, width, device, dtype):
+        """
+        Copied from Diffusers
+        """
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+        latent_image_ids = latent_image_ids.reshape(
+            latent_image_id_height * latent_image_id_width, latent_image_id_channels)
+
+        return latent_image_ids.to(device=device, dtype=dtype)
+
+    def patchify(self, latents):
+        if self.patch_size > 1:
+            bs, c, h, w = latents.size()
+            latents = latents.reshape(
+                bs, c, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size
+            ).permute(
+                0, 1, 3, 5, 2, 4
+            ).reshape(
+                bs, c * self.patch_size * self.patch_size, h // self.patch_size, w // self.patch_size)
+        return latents
+
+    def unpatchify(self, latents):
+        if self.patch_size > 1:
+            bs, c, h, w = latents.size()
+            latents = latents.reshape(
+                bs, c // (self.patch_size * self.patch_size), self.patch_size, self.patch_size, h, w
+            ).permute(
+                0, 1, 4, 2, 5, 3
+            ).reshape(
+                bs, c // (self.patch_size * self.patch_size), h * self.patch_size, w * self.patch_size)
+        return latents
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            timestep: torch.Tensor,
+            encoder_hidden_states: torch.Tensor = None,
+            pooled_projections: torch.Tensor = None,
+            mask: Optional[torch.Tensor] = None,
+            masked_image_latents: Optional[torch.Tensor] = None,
+            **kwargs):
+        hidden_states = self.patchify(hidden_states)
+        bs, c, h, w = hidden_states.size()
+        dtype = hidden_states.dtype
+        device = hidden_states.device
+        hidden_states = hidden_states.reshape(bs, c, h * w).permute(0, 2, 1)
+        img_ids = self._prepare_latent_image_ids(
+            h, w, device, dtype)
+        txt_ids = img_ids.new_zeros((encoder_hidden_states.shape[-2], 3))
+
+        # Flux fill
+        if mask is not None and masked_image_latents is not None:
+            hidden_states = torch.cat(
+                (hidden_states, masked_image_latents.to(dtype=dtype), mask.to(dtype=dtype)), dim=-1)
+
+        output = super().forward(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states.to(dtype),
+            pooled_projections=pooled_projections.to(dtype),
+            timestep=timestep,
+            img_ids=img_ids,
+            txt_ids=txt_ids,
+            return_dict=False,
+            **kwargs)[0]
+
+        output = output.permute(0, 2, 1).reshape(bs, self.out_channels, h, w)
+        return self.unpatchify(output)
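Note: patchify/unpatchify above are intended to be exact inverses; patchify folds each patch_size x patch_size spatial block into the channel dimension before the sequence flattening, and unpatchify undoes it. A quick self-contained round-trip check of the same reshape/permute pattern, with arbitrary shapes:

    import torch

    patch_size, bs, c, h, w = 2, 1, 16, 8, 8
    x = torch.randn(bs, c, h, w)

    packed = x.reshape(
        bs, c, h // patch_size, patch_size, w // patch_size, patch_size
    ).permute(0, 1, 3, 5, 2, 4).reshape(
        bs, c * patch_size ** 2, h // patch_size, w // patch_size)

    restored = packed.reshape(
        bs, c, patch_size, patch_size, h // patch_size, w // patch_size
    ).permute(0, 1, 4, 2, 5, 3).reshape(bs, c, h, w)

    assert torch.equal(x, restored)  # patchify followed by unpatchify is the identity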
lakonlab/models/architecture/diffusers/pretrained.py ADDED
@@ -0,0 +1,281 @@
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.models import AutoencoderKL, AutoencoderKLQwenImage
+from diffusers.pipelines import FluxPipeline, QwenImagePipeline, StableDiffusion3Pipeline
+from mmgen.models.builder import MODULES
+
+# Suppress truncation warnings from transformers and diffusers
+for name in (
+        'transformers.tokenization_utils_base',
+        'transformers.tokenization_utils',
+        'transformers.tokenization_utils_fast'):
+    logging.getLogger(name).setLevel(logging.ERROR)
+
+for name, logger in logging.root.manager.loggerDict.items():
+    if isinstance(logger, logging.Logger) and (name.startswith('diffusers.pipelines.')):
+        logger.setLevel(logging.ERROR)
+
+
+@MODULES.register_module()
+class PretrainedVAE(nn.Module):
+    def __init__(self,
+                 from_pretrained=None,
+                 del_encoder=False,
+                 del_decoder=False,
+                 use_slicing=False,
+                 freeze=True,
+                 eval_mode=True,
+                 torch_dtype='float32',
+                 **kwargs):
+        super().__init__()
+        if torch_dtype is not None:
+            kwargs.update(torch_dtype=getattr(torch, torch_dtype))
+        self.vae = AutoencoderKL.from_pretrained(
+            from_pretrained, **kwargs)
+        if del_encoder:
+            del self.vae.encoder
+        if del_decoder:
+            del self.vae.decoder
+        if use_slicing:
+            self.vae.enable_slicing()
+        self.freeze = freeze
+        self.eval_mode = eval_mode
+        if self.freeze:
+            self.requires_grad_(False)
+        if self.eval_mode:
+            self.eval()
+        self.vae.set_use_memory_efficient_attention_xformers(
+            not hasattr(torch.nn.functional, 'scaled_dot_product_attention'))
+
+    def train(self, mode=True):
+        mode = mode and (not self.eval_mode)
+        return super().train(mode)
+
+    def forward(self, *args, **kwargs):
+        return self.vae(*args, return_dict=False, **kwargs)[0]
+
+    def encode(self, img):
+        scaling_factor = self.vae.config.scaling_factor
+        shift_factor = self.vae.config.shift_factor
+        if scaling_factor is None:
+            scaling_factor = 1.0
+        if shift_factor is None:
+            shift_factor = 0.0
+        return (self.vae.encode(img).latent_dist.sample() - shift_factor) * scaling_factor
+
+    def decode(self, code):
+        scaling_factor = self.vae.config.scaling_factor
+        shift_factor = self.vae.config.shift_factor
+        if scaling_factor is None:
+            scaling_factor = 1.0
+        if shift_factor is None:
+            shift_factor = 0.0
+        return self.vae.decode(code / scaling_factor + shift_factor, return_dict=False)[0]
+
+
+@MODULES.register_module()
+class PretrainedVAEDecoder(PretrainedVAE):
+    def __init__(self, **kwargs):
+        super().__init__(
+            del_encoder=True,
+            del_decoder=False,
+            **kwargs)
+
+    def forward(self, code):
+        return super().decode(code)
+
+
+@MODULES.register_module()
+class PretrainedVAEEncoder(PretrainedVAE):
+    def __init__(self, **kwargs):
+        super().__init__(
+            del_encoder=False,
+            del_decoder=True,
+            **kwargs)
+
+    def forward(self, img):
+        return super().encode(img)
+
+
+@MODULES.register_module()
+class PretrainedVAEQwenImage(nn.Module):
+    def __init__(self,
+                 from_pretrained=None,
+                 use_slicing=False,
+                 freeze=True,
+                 eval_mode=True,
+                 torch_dtype='float32',
+                 **kwargs):
+        super().__init__()
+        if torch_dtype is not None:
+            kwargs.update(torch_dtype=getattr(torch, torch_dtype))
+        self.vae = AutoencoderKLQwenImage.from_pretrained(
+            from_pretrained, **kwargs)
+        if use_slicing:
+            self.vae.enable_slicing()
+        self.freeze = freeze
+        self.eval_mode = eval_mode
+        if self.freeze:
+            self.requires_grad_(False)
+        if self.eval_mode:
+            self.eval()
+
+    def train(self, mode=True):
+        mode = mode and (not self.eval_mode)
+        return super().train(mode)
+
+    def forward(self, *args, **kwargs):
+        return self.vae(*args, return_dict=False, **kwargs)[0]
+
+    def encode(self, img):
+        device = img.device
+        dtype = img.dtype
+        latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=dtype).view(
+            1, self.vae.config.z_dim, 1, 1, 1)
+        latents_std = torch.tensor(self.vae.config.latents_std, device=device, dtype=dtype).view(
+            1, self.vae.config.z_dim, 1, 1, 1)
+        return ((self.vae.encode(img.unsqueeze(-3)).latent_dist.sample() - latents_mean) / latents_std).squeeze(-3)
+
+    def decode(self, code):
+        device = code.device
+        dtype = code.dtype
+        latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=dtype).view(
+            1, self.vae.config.z_dim, 1, 1, 1)
+        latents_std = torch.tensor(self.vae.config.latents_std, device=device, dtype=dtype).view(
+            1, self.vae.config.z_dim, 1, 1, 1)
+        return self.vae.decode(code.unsqueeze(-3) * latents_std + latents_mean, return_dict=False)[0].squeeze(-3)
+
+
+@MODULES.register_module()
+class PretrainedFluxTextEncoder(nn.Module):
+    def __init__(self,
+                 from_pretrained='black-forest-labs/FLUX.1-dev',
+                 freeze=True,
+                 eval_mode=True,
+                 torch_dtype='bfloat16',
+                 max_sequence_length=512,
+                 **kwargs):
+        super().__init__()
+        self.max_sequence_length = max_sequence_length
+        self.pipeline = FluxPipeline.from_pretrained(
+            from_pretrained,
+            scheduler=None,
+            vae=None,
+            transformer=None,
+            image_encoder=None,
+            feature_extractor=None,
+            torch_dtype=getattr(torch, torch_dtype),
+            **kwargs)
+        self.text_encoder = self.pipeline.text_encoder
+        self.text_encoder_2 = self.pipeline.text_encoder_2
+        self.freeze = freeze
+        self.eval_mode = eval_mode
+        if self.freeze:
+            self.requires_grad_(False)
+        if self.eval_mode:
+            self.eval()
+
+    def train(self, mode=True):
+        mode = mode and (not self.eval_mode)
+        return super().train(mode)
+
+    def forward(self, prompt, prompt_2=None):
+        prompt_embeds, pooled_prompt_embeds, text_ids = self.pipeline.encode_prompt(
+            prompt, prompt_2=prompt_2, max_sequence_length=self.max_sequence_length)
+        return dict(
+            encoder_hidden_states=prompt_embeds,
+            pooled_projections=pooled_prompt_embeds)
+
+
+@MODULES.register_module()
+class PretrainedQwenImageTextEncoder(nn.Module):
+    def __init__(self,
+                 from_pretrained='Qwen/Qwen-Image',
+                 freeze=True,
+                 eval_mode=True,
+                 torch_dtype='bfloat16',
+                 max_sequence_length=512,
+                 pad_seq_len=None,
+                 **kwargs):
+        super().__init__()
+        self.max_sequence_length = max_sequence_length
+        if pad_seq_len is not None:
+            assert pad_seq_len >= max_sequence_length
+        self.pad_seq_len = pad_seq_len
+        self.pipeline = QwenImagePipeline.from_pretrained(
+            from_pretrained,
+            scheduler=None,
+            vae=None,
+            transformer=None,
+            torch_dtype=getattr(torch, torch_dtype),
+            **kwargs)
+        self.text_encoder = self.pipeline.text_encoder
+        self.freeze = freeze
+        self.eval_mode = eval_mode
+        if self.freeze:
+            self.requires_grad_(False)
+        if self.eval_mode:
+            self.eval()
+
+    def train(self, mode=True):
+        mode = mode and (not self.eval_mode)
+        return super().train(mode)
+
+    def forward(self, prompt):
+        prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
+            prompt, max_sequence_length=self.max_sequence_length)
+        if self.pad_seq_len is not None:
+            pad_len = self.pad_seq_len - prompt_embeds.size(1)
+            prompt_embeds = F.pad(
+                prompt_embeds, (0, 0, 0, pad_len), value=0.0)
+            prompt_embeds_mask = F.pad(
+                prompt_embeds_mask, (0, pad_len), value=0.0)
+        return dict(
+            encoder_hidden_states=prompt_embeds,
+            encoder_hidden_states_mask=prompt_embeds_mask)
+
+
+@MODULES.register_module()
+class PretrainedStableDiffusion3TextEncoder(nn.Module):
+    def __init__(self,
+                 from_pretrained='stabilityai/stable-diffusion-3.5-large',
+                 freeze=True,
+                 eval_mode=True,
+                 torch_dtype='float32',
+                 max_sequence_length=256,
+                 **kwargs):
+        super().__init__()
+        self.max_sequence_length = max_sequence_length
+        self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+            from_pretrained,
+            scheduler=None,
+            vae=None,
+            transformer=None,
+            image_encoder=None,
+            feature_extractor=None,
+            torch_dtype=getattr(torch, torch_dtype),
+            **kwargs)
+        self.text_encoder = self.pipeline.text_encoder
+        self.text_encoder_2 = self.pipeline.text_encoder_2
+        self.text_encoder_3 = self.pipeline.text_encoder_3
+        self.freeze = freeze
+        self.eval_mode = eval_mode
+        if self.freeze:
+            self.requires_grad_(False)
+        if self.eval_mode:
+            self.eval()
+
+    def train(self, mode=True):
+        mode = mode and (not self.eval_mode)
+        return super().train(mode)
+
+    def forward(self, prompt, prompt_2=None, prompt_3=None):
+        prompt_embeds, _, pooled_prompt_embeds, _ = self.pipeline.encode_prompt(
+            prompt, prompt_2=prompt_2, prompt_3=prompt_3, do_classifier_free_guidance=False,
+            max_sequence_length=self.max_sequence_length)
+        return dict(
+            encoder_hidden_states=prompt_embeds,
+            pooled_projections=pooled_prompt_embeds)
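Note: PretrainedVAE.encode/decode apply the standard diffusers latent normalization, latent = (sample - shift_factor) * scaling_factor, and invert it on decode. A tiny numeric sanity check of just that affine round trip (the factor values below are made up for illustration):

    import torch

    scaling_factor, shift_factor = 0.3611, 0.1159   # illustrative values only
    latent_raw = torch.randn(2, 4, 8, 8)            # stands in for vae.encode(...).latent_dist.sample()

    code = (latent_raw - shift_factor) * scaling_factor   # what encode() returns
    recovered = code / scaling_factor + shift_factor      # what decode() feeds back to the VAE
    assert torch.allclose(recovered, latent_raw, atol=1e-6)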
lakonlab/models/architecture/diffusers/qwen.py ADDED
@@ -0,0 +1,139 @@
+import torch
+
+from accelerate import init_empty_weights
+from diffusers.models import QwenImageTransformer2DModel as _QwenImageTransformer2DModel
+from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_qwen_lora_to_diffusers
+from peft import LoraConfig
+from mmgen.models.builder import MODULES
+from mmgen.utils import get_root_logger
+from ..utils import flex_freeze
+from lakonlab.runner.checkpoint import load_checkpoint, _load_checkpoint
+
+
+@MODULES.register_module()
+class QwenImageTransformer2DModel(_QwenImageTransformer2DModel):
+
+    def __init__(
+            self,
+            *args,
+            patch_size=2,
+            freeze=False,
+            freeze_exclude=[],
+            pretrained=None,
+            pretrained_lora=None,
+            pretrained_lora_scale=1.0,
+            torch_dtype='float32',
+            freeze_exclude_fp32=True,
+            freeze_exclude_autocast_dtype='float32',
+            checkpointing=True,
+            use_lora=False,
+            lora_target_modules=None,
+            lora_rank=16,
+            **kwargs):
+        with init_empty_weights():
+            super().__init__(*args, patch_size=1, **kwargs)
+        self.patch_size = patch_size
+
+        self.init_weights(pretrained, pretrained_lora, pretrained_lora_scale)
+
+        self.use_lora = use_lora
+        self.lora_target_modules = lora_target_modules
+        self.lora_rank = lora_rank
+        if self.use_lora:
+            transformer_lora_config = LoraConfig(
+                r=lora_rank,
+                lora_alpha=lora_rank,
+                init_lora_weights='gaussian',
+                target_modules=lora_target_modules,
+            )
+            self.add_adapter(transformer_lora_config)
+
+        if torch_dtype is not None:
+            self.to(getattr(torch, torch_dtype))
+
+        self.freeze = freeze
+        if self.freeze:
+            flex_freeze(
+                self,
+                exclude_keys=freeze_exclude,
+                exclude_fp32=freeze_exclude_fp32,
+                exclude_autocast_dtype=freeze_exclude_autocast_dtype)
+
+        if checkpointing:
+            self.enable_gradient_checkpointing()
+
+    def init_weights(self, pretrained=None, pretrained_lora=None, pretrained_lora_scale=1.0):
+        if pretrained is not None:
+            logger = get_root_logger()
+            load_checkpoint(
+                self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
+        if pretrained_lora is not None:
+            if not isinstance(pretrained_lora, (list, tuple)):
+                assert isinstance(pretrained_lora, str)
+                pretrained_lora = [pretrained_lora]
+            if not isinstance(pretrained_lora_scale, (list, tuple)):
+                assert isinstance(pretrained_lora_scale, (int, float))
+                pretrained_lora_scale = [pretrained_lora_scale]
+            for pretrained_lora_single, pretrained_lora_scale_single in zip(pretrained_lora, pretrained_lora_scale):
+                lora_state_dict = _load_checkpoint(
+                    pretrained_lora_single, map_location='cpu', logger=logger)
+                lora_state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(lora_state_dict)
+                self.load_lora_adapter(lora_state_dict)
+                self.fuse_lora(lora_scale=pretrained_lora_scale_single)
+                self.unload_lora()
+
+    def patchify(self, latents):
+        if self.patch_size > 1:
+            bs, c, h, w = latents.size()
+            latents = latents.reshape(
+                bs, c, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size
+            ).permute(
+                0, 1, 3, 5, 2, 4
+            ).reshape(
+                bs, c * self.patch_size * self.patch_size, h // self.patch_size, w // self.patch_size)
+        return latents
+
+    def unpatchify(self, latents):
+        if self.patch_size > 1:
+            bs, c, h, w = latents.size()
+            latents = latents.reshape(
+                bs, c // (self.patch_size * self.patch_size), self.patch_size, self.patch_size, h, w
+            ).permute(
+                0, 1, 4, 2, 5, 3
+            ).reshape(
+                bs, c // (self.patch_size * self.patch_size), h * self.patch_size, w * self.patch_size)
+        return latents
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            timestep: torch.Tensor,
+            encoder_hidden_states: torch.Tensor = None,
+            encoder_hidden_states_mask: torch.Tensor = None,
+            **kwargs):
+        hidden_states = self.patchify(hidden_states)
+        bs, c, h, w = hidden_states.size()
+        dtype = hidden_states.dtype
+        hidden_states = hidden_states.reshape(bs, c, h * w).permute(0, 2, 1)
+        img_shapes = [[(1, h, w)]]
+        if encoder_hidden_states_mask is not None:
+            txt_seq_lens = encoder_hidden_states_mask.sum(dim=1)
+            max_txt_seq_len = txt_seq_lens.max()
+            encoder_hidden_states = encoder_hidden_states[:, :max_txt_seq_len]
+            encoder_hidden_states_mask = encoder_hidden_states_mask[:, :max_txt_seq_len]
+            txt_seq_lens = txt_seq_lens.tolist()
+        else:
+            txt_seq_lens = None
+
+        output = super().forward(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states.to(dtype),
+            encoder_hidden_states_mask=encoder_hidden_states_mask,
+            timestep=timestep,
+            img_shapes=img_shapes,
+            txt_seq_lens=txt_seq_lens,
+            return_dict=False,
+            **kwargs)[0]
+
+        output = output.permute(0, 2, 1).reshape(bs, self.out_channels, h, w)
+        return self.unpatchify(output)
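Note: because PretrainedQwenImageTextEncoder can right-pad prompt embeddings to a fixed pad_seq_len, the forward pass above trims them back to the longest real sequence in the batch using the mask before calling the base transformer. A small standalone illustration of that trimming step, with arbitrary sizes:

    import torch

    # Two prompts padded to length 6; their real lengths are 3 and 5.
    mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                         [1, 1, 1, 1, 1, 0]])
    embeds = torch.randn(2, 6, 16)

    txt_seq_lens = mask.sum(dim=1)      # tensor([3, 5])
    max_len = txt_seq_lens.max()        # 5
    embeds = embeds[:, :max_len]        # drop columns that are padding for every sample
    mask = mask[:, :max_len]
    assert embeds.shape[1] == 5 and txt_seq_lens.tolist() == [3, 5]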
lakonlab/models/architecture/diffusers/sd3.py ADDED
@@ -0,0 +1,80 @@
+import torch
+
+from accelerate import init_empty_weights
+from diffusers.models import SD3Transformer2DModel as _SD3Transformer2DModel
+from peft import LoraConfig
+from mmgen.models.builder import MODULES
+from mmgen.utils import get_root_logger
+from ..utils import flex_freeze
+from lakonlab.runner.checkpoint import load_checkpoint
+
+
+@MODULES.register_module()
+class SD3Transformer2DModel(_SD3Transformer2DModel):
+
+    def __init__(
+            self,
+            *args,
+            freeze=False,
+            freeze_exclude=[],
+            pretrained=None,
+            torch_dtype='float32',
+            freeze_exclude_fp32=True,
+            freeze_exclude_autocast_dtype='float32',
+            checkpointing=True,
+            use_lora=False,
+            lora_target_modules=None,
+            lora_rank=16,
+            **kwargs):
+        with init_empty_weights():
+            super().__init__(*args, **kwargs)
+        self.init_weights(pretrained)
+
+        self.use_lora = use_lora
+        self.lora_target_modules = lora_target_modules
+        self.lora_rank = lora_rank
+        if self.use_lora:
+            transformer_lora_config = LoraConfig(
+                r=lora_rank,
+                lora_alpha=lora_rank,
+                init_lora_weights='gaussian',
+                target_modules=lora_target_modules,
+            )
+            self.add_adapter(transformer_lora_config)
+
+        if torch_dtype is not None:
+            self.to(getattr(torch, torch_dtype))
+
+        self.freeze = freeze
+        if self.freeze:
+            flex_freeze(
+                self,
+                exclude_keys=freeze_exclude,
+                exclude_fp32=freeze_exclude_fp32,
+                exclude_autocast_dtype=freeze_exclude_autocast_dtype)
+
+        if checkpointing:
+            self.enable_gradient_checkpointing()
+
+    def init_weights(self, pretrained=None):
+        if pretrained is not None:
+            logger = get_root_logger()
+            load_checkpoint(
+                self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            timestep: torch.Tensor,
+            encoder_hidden_states: torch.Tensor = None,
+            pooled_projections: torch.Tensor = None,
+            **kwargs):
+        dtype = hidden_states.dtype
+
+        return super().forward(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states.to(dtype),
+            pooled_projections=pooled_projections.to(dtype),
+            timestep=timestep,
+            return_dict=False,
+            **kwargs)[0]
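Note: these wrappers build the transformer under accelerate's init_empty_weights(), so parameters start on the meta device and only get real storage when a checkpoint is loaded with assign=True. A minimal sketch of that pattern on a toy module, not this project's checkpoint loader (assign=True requires PyTorch 2.1 or newer):

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights

    with init_empty_weights():
        layer = nn.Linear(4, 4)              # allocated on the 'meta' device, no real storage
    assert layer.weight.device.type == 'meta'

    state_dict = {'weight': torch.randn(4, 4), 'bias': torch.zeros(4)}
    layer.load_state_dict(state_dict, assign=True)   # assign=True swaps in the real tensors
    assert layer.weight.device.type == 'cpu'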
lakonlab/models/architecture/diffusers/unet.py ADDED
@@ -0,0 +1,192 @@
+import torch
+import torch.nn.functional as F
+
+from typing import Dict, Any, Optional, Union, Tuple
+from collections import OrderedDict
+from diffusers.models import UNet2DConditionModel as _UNet2DConditionModel
+from mmcv.runner import _load_checkpoint, load_state_dict
+from mmgen.models.builder import MODULES
+from mmgen.utils import get_root_logger
+from ..utils import flex_freeze
+
+
+def ceildiv(a, b):
+    return -(a // -b)
+
+
+def unet_enc(
+        unet,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs=None):
+    # 0. center input if necessary
+    if unet.config.center_input_sample:
+        sample = 2 * sample - 1.0
+
+    # 1. time
+    t_emb = unet.get_time_embed(sample=sample, timestep=timestep)
+    emb = unet.time_embedding(t_emb)
+    aug_emb = unet.get_aug_embed(
+        emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
+    emb = emb + aug_emb if aug_emb is not None else emb
+
+    if unet.time_embed_act is not None:
+        emb = unet.time_embed_act(emb)
+
+    encoder_hidden_states = unet.process_encoder_hidden_states(
+        encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
+
+    # 2. pre-process
+    sample = unet.conv_in(sample)
+
+    # 3. down
+    down_block_res_samples = (sample,)
+    for downsample_block in unet.down_blocks:
+        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+            sample, res_samples = downsample_block(
+                hidden_states=sample,
+                temb=emb,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+            )
+        else:
+            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+        down_block_res_samples += res_samples
+
+    return emb, down_block_res_samples, sample
+
+
+def unet_dec(
+        unet,
+        emb,
+        down_block_res_samples,
+        sample,
+        encoder_hidden_states: torch.Tensor,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None):
+    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+
+    if is_controlnet:
+        new_down_block_res_samples = ()
+
+        for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals):
+            down_block_res_sample = down_block_res_sample + down_block_additional_residual
+            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+        down_block_res_samples = new_down_block_res_samples
+
+    # 4. mid
+    if unet.mid_block is not None:
+        if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
+            sample = unet.mid_block(
+                sample,
+                emb,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+            )
+        else:
+            sample = unet.mid_block(sample, emb)
+
+    if is_controlnet:
+        sample = sample + mid_block_additional_residual
+
+    # 5. up
+    for i, upsample_block in enumerate(unet.up_blocks):
+        res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+        if hasattr(upsample_block, 'has_cross_attention') and upsample_block.has_cross_attention:
+            sample = upsample_block(
+                hidden_states=sample,
+                temb=emb,
+                res_hidden_states_tuple=res_samples,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+            )
+        else:
+            sample = upsample_block(
+                hidden_states=sample,
+                temb=emb,
+                res_hidden_states_tuple=res_samples,
+            )
+
+    # 6. post-process
+    if unet.conv_norm_out:
+        sample = unet.conv_norm_out(sample)
+        sample = unet.conv_act(sample)
+    sample = unet.conv_out(sample)
+
+    return sample
+
+
+@MODULES.register_module()
+class UNet2DConditionModel(_UNet2DConditionModel):
+    def __init__(self,
+                 *args,
+                 freeze=True,
+                 freeze_exclude=[],
+                 pretrained=None,
+                 torch_dtype='float32',
+                 freeze_exclude_fp32=True,
+                 freeze_exclude_autocast_dtype='float32',
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.init_weights(pretrained)
+        if torch_dtype is not None:
+            self.to(getattr(torch, torch_dtype))
+
+        self.set_use_memory_efficient_attention_xformers(
+            not hasattr(torch.nn.functional, 'scaled_dot_product_attention'))
+
+        self.freeze = freeze
+        if self.freeze:
+            flex_freeze(
+                self,
+                exclude_keys=freeze_exclude,
+                exclude_fp32=freeze_exclude_fp32,
+                exclude_autocast_dtype=freeze_exclude_autocast_dtype)
+
+    def init_weights(self, pretrained):
+        if pretrained is not None:
+            logger = get_root_logger()
+            # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
+            checkpoint = _load_checkpoint(pretrained, map_location='cpu', logger=logger)
+            if 'state_dict' in checkpoint:
+                state_dict = checkpoint['state_dict']
+            else:
+                state_dict = checkpoint
+            metadata = getattr(state_dict, '_metadata', OrderedDict())
+            state_dict._metadata = metadata
+            assert self.conv_in.weight.shape[1] == self.conv_out.weight.shape[0]
+            if state_dict['conv_in.weight'].size() != self.conv_in.weight.size():
+                assert state_dict['conv_in.weight'].shape[1] == state_dict['conv_out.weight'].shape[0]
+                src_chn = state_dict['conv_in.weight'].shape[1]
+                tgt_chn = self.conv_in.weight.shape[1]
+                assert src_chn < tgt_chn
+                convert_mat_out = torch.tile(torch.eye(src_chn), (ceildiv(tgt_chn, src_chn), 1))
+                convert_mat_out = convert_mat_out[:tgt_chn]
+                convert_mat_in = F.normalize(convert_mat_out.pinverse(), dim=-1)
+                state_dict['conv_out.weight'] = torch.einsum(
+                    'ts,scxy->tcxy', convert_mat_out, state_dict['conv_out.weight'])
+                state_dict['conv_out.bias'] = torch.einsum(
+                    'ts,s->t', convert_mat_out, state_dict['conv_out.bias'])
+                state_dict['conv_in.weight'] = torch.einsum(
+                    'st,csxy->ctxy', convert_mat_in, state_dict['conv_in.weight'])
+            load_state_dict(self, state_dict, logger=logger)
+
+    def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
+        dtype = sample.dtype
+        return super().forward(
+            sample, timestep, encoder_hidden_states, return_dict=False, **kwargs)[0].to(dtype)
+
+    def forward_enc(self, sample, timestep, encoder_hidden_states, **kwargs):
+        return unet_enc(self, sample, timestep, encoder_hidden_states, **kwargs)
+
+    def forward_dec(self, emb, down_block_res_samples, sample, encoder_hidden_states, **kwargs):
+        return unet_dec(self, emb, down_block_res_samples, sample, encoder_hidden_states, **kwargs)
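Note: init_weights above adapts a checkpoint with fewer latent channels to a UNet with more of them: conv_out channels are duplicated via a tiled identity matrix, and conv_in uses the row-normalized pseudo-inverse of that matrix so the duplicated channels are averaged back on the way in. A small sketch of how those conversion matrices come out, with illustrative channel counts:

    import torch
    import torch.nn.functional as F

    def ceildiv(a, b):
        return -(a // -b)

    src_chn, tgt_chn = 4, 9   # e.g. a 4-channel checkpoint adapted to a 9-channel model
    convert_mat_out = torch.tile(torch.eye(src_chn), (ceildiv(tgt_chn, src_chn), 1))[:tgt_chn]
    convert_mat_in = F.normalize(convert_mat_out.pinverse(), dim=-1)

    print(convert_mat_out.shape, convert_mat_in.shape)   # torch.Size([9, 4]) torch.Size([4, 9])
    # Each target output channel copies one source channel; the inverse map folds the copies back.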
models/lakonlab/models/architecture/diffusers/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+from .pretrained import (
+    PretrainedVAE, PretrainedVAEDecoder, PretrainedVAEEncoder, PretrainedVAEQwenImage,
+    PretrainedFluxTextEncoder, PretrainedQwenImageTextEncoder, PretrainedStableDiffusion3TextEncoder)
+from .unet import UNet2DConditionModel
+from .flux import FluxTransformer2DModel
+from .dit import DiTTransformer2DModelMod
+from .sd3 import SD3Transformer2DModel
+from .qwen import QwenImageTransformer2DModel
+
+__all__ = [
+    'PretrainedVAE', 'PretrainedVAEDecoder', 'PretrainedVAEEncoder', 'PretrainedFluxTextEncoder',
+    'PretrainedQwenImageTextEncoder', 'UNet2DConditionModel', 'FluxTransformer2DModel',
+    'DiTTransformer2DModelMod', 'SD3Transformer2DModel',
+    'QwenImageTransformer2DModel', 'PretrainedVAEQwenImage', 'PretrainedStableDiffusion3TextEncoder',
+]
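For orientation only (not part of the commit): these re-exports exist so the wrappers can be instantiated from mmgen-style config dicts through the MODULES registry. The sketch below assumes the usual mmgen `build_module` helper and uses placeholder config values.

```python
# Illustrative only: building one of the registered wrappers from a config dict
# via mmgen's registry. `build_module` is assumed to be the standard mmgen
# builder entry point; the config values are placeholders.
from mmgen.models.builder import build_module

unet_cfg = dict(
    type='UNet2DConditionModel',  # looked up in the MODULES registry
    pretrained=None,              # optionally a path to a converted checkpoint
    torch_dtype='float32',
    freeze=False)

unet = build_module(unet_cfg)
print(type(unet).__name__)        # 'UNet2DConditionModel'
```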
models/lakonlab/models/architecture/diffusers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (828 Bytes).
models/lakonlab/models/architecture/diffusers/__pycache__/dit.cpython-310.pyc ADDED
Binary file (12.4 kB).
models/lakonlab/models/architecture/diffusers/__pycache__/flux.cpython-310.pyc ADDED
Binary file (4.66 kB).
models/lakonlab/models/architecture/diffusers/__pycache__/pretrained.cpython-310.pyc ADDED
Binary file (8.72 kB).
models/lakonlab/models/architecture/diffusers/__pycache__/qwen.cpython-310.pyc ADDED
Binary file (4.14 kB).
models/lakonlab/models/architecture/diffusers/__pycache__/sd3.cpython-310.pyc ADDED
Binary file (2.43 kB).
models/lakonlab/models/architecture/diffusers/__pycache__/unet.cpython-310.pyc ADDED
Binary file (5.13 kB).
models/lakonlab/models/architecture/diffusers/dit.py ADDED
@@ -0,0 +1,428 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Any, Dict, Optional
6
+ from diffusers.models import DiTTransformer2DModel, ModelMixin
7
+ from diffusers.models.attention import BasicTransformerBlock, _chunked_feed_forward, Attention, FeedForward
8
+ from diffusers.models.embeddings import (
9
+ PatchEmbed, Timesteps, CombinedTimestepLabelEmbeddings, TimestepEmbedding, LabelEmbedding)
10
+ from diffusers.models.normalization import AdaLayerNormZero
11
+ from diffusers.configuration_utils import register_to_config
12
+ from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict
13
+ from mmcv.cnn import constant_init, xavier_init
14
+ from mmgen.models.builder import MODULES
15
+ from mmgen.utils import get_root_logger
16
+ from ..utils import flex_freeze
17
+
18
+
19
+ class LabelEmbeddingMod(LabelEmbedding):
20
+ def __init__(self, num_classes, hidden_size, dropout_prob=0.0, use_cfg_embedding=True):
21
+ super(LabelEmbedding, self).__init__()
22
+ if dropout_prob > 0:
23
+ assert use_cfg_embedding
24
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
25
+ self.num_classes = num_classes
26
+ self.dropout_prob = dropout_prob
27
+
28
+
29
+ class CombinedTimestepLabelEmbeddingsMod(CombinedTimestepLabelEmbeddings):
30
+ """
31
+ Modified CombinedTimestepLabelEmbeddings for reproducing the original DiT (downscale_freq_shift=0).
32
+ """
33
+ def __init__(
34
+ self, num_classes, embedding_dim, class_dropout_prob=0.1, downscale_freq_shift=0, use_cfg_embedding=True):
35
+ super(CombinedTimestepLabelEmbeddings, self).__init__()
36
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=downscale_freq_shift)
37
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
38
+ self.class_embedder = LabelEmbeddingMod(num_classes, embedding_dim, class_dropout_prob, use_cfg_embedding)
39
+
40
+
41
+ class BasicTransformerBlockMod(BasicTransformerBlock):
42
+ """
43
+ Modified BasicTransformerBlock for reproducing the original DiT with shared time and class
44
+ embeddings across all layers.
45
+ """
46
+ def __init__(
47
+ self,
48
+ dim: int,
49
+ num_attention_heads: int,
50
+ attention_head_dim: int,
51
+ dropout=0.0,
52
+ cross_attention_dim: Optional[int] = None,
53
+ activation_fn: str = 'geglu',
54
+ num_embeds_ada_norm: Optional[int] = None,
55
+ attention_bias: bool = False,
56
+ only_cross_attention: bool = False,
57
+ double_self_attention: bool = False,
58
+ upcast_attention: bool = False,
59
+ norm_elementwise_affine: bool = True,
60
+ norm_type: str = 'layer_norm',
61
+ norm_eps: float = 1e-5,
62
+ final_dropout: bool = False,
63
+ attention_type: str = 'default',
64
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
65
+ ada_norm_bias: Optional[int] = None,
66
+ ff_inner_dim: Optional[int] = None,
67
+ ff_bias: bool = True,
68
+ attention_out_bias: bool = True):
69
+ super(BasicTransformerBlock, self).__init__()
70
+ self.only_cross_attention = only_cross_attention
71
+ self.norm_type = norm_type
72
+ self.num_embeds_ada_norm = num_embeds_ada_norm
73
+
74
+ assert self.norm_type == 'ada_norm_zero'
75
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
76
+ self.attn1 = Attention(
77
+ query_dim=dim,
78
+ heads=num_attention_heads,
79
+ dim_head=attention_head_dim,
80
+ dropout=dropout,
81
+ bias=attention_bias,
82
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
83
+ upcast_attention=upcast_attention,
84
+ out_bias=attention_out_bias,
85
+ )
86
+
87
+ self.norm2 = None
88
+ self.attn2 = None
89
+
90
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
91
+ self.ff = FeedForward(
92
+ dim,
93
+ dropout=dropout,
94
+ activation_fn=activation_fn,
95
+ final_dropout=final_dropout,
96
+ inner_dim=ff_inner_dim,
97
+ bias=ff_bias,
98
+ )
99
+
100
+ self._chunk_size = None
101
+ self._chunk_dim = 0
102
+
103
+ def forward(
104
+ self,
105
+ hidden_states: torch.Tensor,
106
+ attention_mask: Optional[torch.Tensor] = None,
107
+ encoder_hidden_states: Optional[torch.Tensor] = None,
108
+ encoder_attention_mask: Optional[torch.Tensor] = None,
109
+ timestep: Optional[torch.LongTensor] = None,
110
+ cross_attention_kwargs: Dict[str, Any] = None,
111
+ class_labels: Optional[torch.LongTensor] = None,
112
+ emb: Optional[torch.Tensor] = None,
113
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
114
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
115
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype, emb=emb)
116
+
117
+ if cross_attention_kwargs is None:
118
+ cross_attention_kwargs = dict()
119
+ attn_output = self.attn1(
120
+ norm_hidden_states,
121
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
122
+ attention_mask=attention_mask,
123
+ **cross_attention_kwargs)
124
+ attn_output = gate_msa.unsqueeze(1) * attn_output
125
+
126
+ hidden_states = attn_output + hidden_states
127
+ if hidden_states.ndim == 4:
128
+ hidden_states = hidden_states.squeeze(1)
129
+
130
+ norm_hidden_states = self.norm3(hidden_states)
131
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
132
+
133
+ if self._chunk_size is not None:
134
+ # "feed_forward_chunk_size" can be used to save memory
135
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
136
+ else:
137
+ ff_output = self.ff(norm_hidden_states)
138
+
139
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
140
+
141
+ hidden_states = ff_output + hidden_states
142
+ if hidden_states.ndim == 4:
143
+ hidden_states = hidden_states.squeeze(1)
144
+
145
+ return hidden_states
146
+
147
+
148
+ class _DiTTransformer2DModelMod(DiTTransformer2DModel):
149
+
150
+ @register_to_config
151
+ def __init__(
152
+ self,
153
+ class_dropout_prob=0.0,
154
+ num_attention_heads: int = 16,
155
+ attention_head_dim: int = 72,
156
+ in_channels: int = 4,
157
+ out_channels: Optional[int] = None,
158
+ num_layers: int = 28,
159
+ dropout: float = 0.0,
160
+ norm_num_groups: int = 32,
161
+ attention_bias: bool = True,
162
+ sample_size: int = 32,
163
+ patch_size: int = 2,
164
+ activation_fn: str = 'gelu-approximate',
165
+ num_embeds_ada_norm: Optional[int] = 1000,
166
+ upcast_attention: bool = False,
167
+ norm_type: str = 'ada_norm_zero',
168
+ norm_elementwise_affine: bool = False,
169
+ norm_eps: float = 1e-5):
170
+
171
+ super(DiTTransformer2DModel, self).__init__()
172
+
173
+ # Validate inputs.
174
+ if norm_type != "ada_norm_zero":
175
+ raise NotImplementedError(
176
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
177
+ )
178
+ elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
179
+ raise ValueError(
180
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
181
+ )
182
+
183
+ # Set some common variables used across the board.
184
+ self.attention_head_dim = attention_head_dim
185
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
186
+ self.out_channels = in_channels if out_channels is None else out_channels
187
+ self.gradient_checkpointing = False
188
+
189
+ # 2. Initialize the position embedding and transformer blocks.
190
+ self.height = self.config.sample_size
191
+ self.width = self.config.sample_size
192
+
193
+ self.patch_size = self.config.patch_size
194
+ self.pos_embed = PatchEmbed(
195
+ height=self.config.sample_size,
196
+ width=self.config.sample_size,
197
+ patch_size=self.config.patch_size,
198
+ in_channels=self.config.in_channels,
199
+ embed_dim=self.inner_dim)
200
+ self.emb = CombinedTimestepLabelEmbeddingsMod(
201
+ num_embeds_ada_norm, self.inner_dim, class_dropout_prob=0.0)
202
+
203
+ self.transformer_blocks = nn.ModuleList([
204
+ BasicTransformerBlockMod(
205
+ self.inner_dim,
206
+ self.config.num_attention_heads,
207
+ self.config.attention_head_dim,
208
+ dropout=self.config.dropout,
209
+ activation_fn=self.config.activation_fn,
210
+ num_embeds_ada_norm=None,
211
+ attention_bias=self.config.attention_bias,
212
+ upcast_attention=self.config.upcast_attention,
213
+ norm_type=norm_type,
214
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
215
+ norm_eps=self.config.norm_eps)
216
+ for _ in range(self.config.num_layers)])
217
+
218
+ # 3. Output blocks.
219
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
220
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
221
+ self.proj_out_2 = nn.Linear(
222
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
223
+
224
+ # https://github.com/facebookresearch/DiT/blob/main/models.py
225
+ def init_weights(self):
226
+ for m in self.modules():
227
+ if isinstance(m, nn.Linear):
228
+ xavier_init(m, distribution='uniform')
229
+ elif isinstance(m, nn.Embedding):
230
+ torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
231
+
232
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
233
+ w = self.pos_embed.proj.weight.data
234
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
235
+ nn.init.constant_(self.pos_embed.proj.bias, 0)
236
+
237
+ # Zero-out adaLN modulation layers in DiT blocks
238
+ for m in self.modules():
239
+ if isinstance(m, AdaLayerNormZero):
240
+ constant_init(m.linear, val=0)
241
+
242
+ # Zero-out output layers
243
+ constant_init(self.proj_out_1, val=0)
244
+ constant_init(self.proj_out_2, val=0)
245
+
246
+ def forward(
247
+ self,
248
+ hidden_states: torch.Tensor,
249
+ timestep: Optional[torch.LongTensor] = None,
250
+ class_labels: Optional[torch.LongTensor] = None,
251
+ cross_attention_kwargs: Dict[str, Any] = None):
252
+ # 1. Input
253
+ bs, _, h, w = hidden_states.size()
254
+ height, width = h // self.patch_size, w // self.patch_size
255
+ hidden_states = self.pos_embed(hidden_states)
256
+
257
+ cond_emb = self.emb(
258
+ timestep, class_labels, hidden_dtype=hidden_states.dtype)
259
+ dropout_enabled = self.config.class_dropout_prob > 0 and self.training
260
+ if dropout_enabled:
261
+ uncond_emb = self.emb(timestep, torch.full_like(
262
+ class_labels, self.config.num_embeds_ada_norm), hidden_dtype=hidden_states.dtype)
263
+
264
+ # 2. Blocks
265
+ for block in self.transformer_blocks:
266
+ if dropout_enabled:
267
+ dropout_mask = torch.rand((bs, 1), device=hidden_states.device) < self.config.class_dropout_prob
268
+ emb = torch.where(dropout_mask, uncond_emb, cond_emb)
269
+ else:
270
+ emb = cond_emb
271
+
272
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
273
+
274
+ def create_custom_forward(module, return_dict=None):
275
+ def custom_forward(*inputs):
276
+ if return_dict is not None:
277
+ return module(*inputs, return_dict=return_dict)
278
+ else:
279
+ return module(*inputs)
280
+
281
+ return custom_forward
282
+
283
+ hidden_states = torch.utils.checkpoint.checkpoint(
284
+ create_custom_forward(block),
285
+ hidden_states,
286
+ None,
287
+ None,
288
+ None,
289
+ timestep,
290
+ cross_attention_kwargs,
291
+ class_labels,
292
+ emb,
293
+ use_reentrant=False)
294
+
295
+ else:
296
+ hidden_states = block(
297
+ hidden_states,
298
+ attention_mask=None,
299
+ encoder_hidden_states=None,
300
+ encoder_attention_mask=None,
301
+ timestep=timestep,
302
+ cross_attention_kwargs=cross_attention_kwargs,
303
+ class_labels=class_labels,
304
+ emb=emb)
305
+
306
+ # 3. Output
307
+ if dropout_enabled:
308
+ dropout_mask = torch.rand((bs, 1), device=hidden_states.device) < self.config.class_dropout_prob
309
+ emb = torch.where(dropout_mask, uncond_emb, cond_emb)
310
+ else:
311
+ emb = cond_emb
312
+ shift, scale = self.proj_out_1(F.silu(emb)).chunk(2, dim=1)
313
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
314
+ output = self.proj_out_2(hidden_states).reshape(
315
+ bs, height, width, self.patch_size, self.patch_size, self.out_channels
316
+ ).permute(0, 5, 1, 3, 2, 4).reshape(
317
+ bs, self.out_channels, height * self.patch_size, width * self.patch_size)
318
+
319
+ return output
320
+
321
+
322
+ @MODULES.register_module()
323
+ class DiTTransformer2DModelMod(_DiTTransformer2DModelMod):
324
+
325
+ def __init__(
326
+ self,
327
+ *args,
328
+ freeze=False,
329
+ freeze_exclude=[],
330
+ pretrained=None,
331
+ torch_dtype='float32',
332
+ autocast_dtype=None,
333
+ freeze_exclude_fp32=True,
334
+ freeze_exclude_autocast_dtype='float32',
335
+ checkpointing=True,
336
+ **kwargs):
337
+ super().__init__(*args, **kwargs)
338
+
339
+ self.init_weights(pretrained)
340
+
341
+ if autocast_dtype is not None:
342
+ assert torch_dtype == 'float32'
343
+ self.autocast_dtype = autocast_dtype
344
+
345
+ if torch_dtype is not None:
346
+ self.to(getattr(torch, torch_dtype))
347
+
348
+ self.freeze = freeze
349
+ if self.freeze:
350
+ flex_freeze(
351
+ self,
352
+ exclude_keys=freeze_exclude,
353
+ exclude_fp32=freeze_exclude_fp32,
354
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
355
+
356
+ if checkpointing:
357
+ self.enable_gradient_checkpointing()
358
+
359
+ def init_weights(self, pretrained=None):
360
+ super().init_weights()
361
+ if pretrained is not None:
362
+ logger = get_root_logger()
363
+ # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
364
+ checkpoint = _load_checkpoint(pretrained, map_location='cpu', logger=logger)
365
+ if 'state_dict' in checkpoint:
366
+ state_dict = checkpoint['state_dict']
367
+ else:
368
+ state_dict = checkpoint
369
+ # load from GMDiT V1 model with 1 Gaussian
370
+ p2 = self.config.patch_size * self.config.patch_size
371
+ ori_out_channels = p2 * self.out_channels
372
+ if 'proj_out_2.weight' in state_dict:
373
+ # if this is GMDiT V1 model with 1 Gaussian
374
+ if state_dict['proj_out_2.weight'].size(0) == p2 * (self.out_channels + 1):
375
+ state_dict['proj_out_2.weight'] = state_dict['proj_out_2.weight'].reshape(
376
+ p2, self.out_channels + 1, -1
377
+ )[:, :-1].reshape(ori_out_channels, -1)
378
+ # if this is original DiT with variance prediction
379
+ if state_dict['proj_out_2.weight'].size(0) == 2 * ori_out_channels:
380
+ state_dict['proj_out_2.weight'] = state_dict['proj_out_2.weight'].reshape(
381
+ p2, 2 * self.out_channels, -1
382
+ )[:, :self.out_channels].reshape(ori_out_channels, -1)
383
+ if 'proj_out_2.bias' in state_dict:
384
+ # if this is GMDiT V1 model with 1 Gaussian
385
+ if state_dict['proj_out_2.bias'].size(0) == p2 * (self.out_channels + 1):
386
+ state_dict['proj_out_2.bias'] = state_dict['proj_out_2.bias'].reshape(
387
+ p2, self.out_channels + 1
388
+ )[:, :-1].reshape(ori_out_channels)
389
+ # if this is original DiT with variance prediction
390
+ if state_dict['proj_out_2.bias'].size(0) == 2 * ori_out_channels:
391
+ state_dict['proj_out_2.bias'] = state_dict['proj_out_2.bias'].reshape(
392
+ p2, 2 * self.out_channels
393
+ )[:, :self.out_channels].reshape(ori_out_channels)
394
+ if 'emb.class_embedder.embedding_table.weight' not in state_dict \
395
+ and 'transformer_blocks.0.norm1.emb.class_embedder.embedding_table.weight' in state_dict:
396
+ # convert original diffusers DiT model to our modified DiT model with shared embeddings
397
+ keys_to_remove = []
398
+ state_update = {}
399
+ for k, v in state_dict.items():
400
+ if k.startswith('transformer_blocks.0.norm1.emb.'):
401
+ new_k = k.replace('transformer_blocks.0.norm1.', '')
402
+ state_update[new_k] = v
403
+ if k.startswith('transformer_blocks.') and '.norm1.emb.' in k:
404
+ keys_to_remove.append(k)
405
+ state_dict.update(state_update)
406
+ for k in keys_to_remove:
407
+ del state_dict[k]
408
+ load_state_dict(self, state_dict, logger=logger)
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.Tensor,
413
+ timestep: Optional[torch.LongTensor] = None,
414
+ class_labels: Optional[torch.LongTensor] = None,
415
+ **kwargs):
416
+ if self.autocast_dtype is not None:
417
+ dtype = getattr(torch, self.autocast_dtype)
418
+ else:
419
+ dtype = hidden_states.dtype
420
+ with torch.autocast(
421
+ device_type='cuda',
422
+ enabled=self.autocast_dtype is not None,
423
+ dtype=dtype if self.autocast_dtype is not None else None):
424
+ return super().forward(
425
+ hidden_states.to(dtype),
426
+ timestep=timestep,
427
+ class_labels=class_labels,
428
+ **kwargs)
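One detail of `_DiTTransformer2DModelMod.forward` above worth calling out is the per-block classifier-free dropout: when training with `class_dropout_prob > 0`, the conditioning embedding is re-drawn for every transformer block and swapped for the null-class embedding via `torch.where`. A minimal sketch with made-up shapes:

```python
# Minimal sketch (illustrative shapes) of the per-block label dropout used in
# _DiTTransformer2DModelMod.forward: with probability class_dropout_prob each
# sample's conditioning embedding is replaced by the null-class embedding,
# re-sampled independently for every transformer block.
import torch

bs, dim = 4, 8                      # hypothetical batch size and embedding dim
class_dropout_prob = 0.1
cond_emb = torch.randn(bs, dim)                    # timestep + true-class embedding
uncond_emb = torch.randn(1, dim).expand(bs, dim)   # null-class embedding

dropout_mask = torch.rand(bs, 1) < class_dropout_prob
emb = torch.where(dropout_mask, uncond_emb, cond_emb)
print(emb.shape)  # torch.Size([4, 8])
```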
models/lakonlab/models/architecture/diffusers/flux.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+
3
+ from typing import Optional
4
+ from accelerate import init_empty_weights
5
+ from diffusers.models import FluxTransformer2DModel as _FluxTransformer2DModel
6
+ from peft import LoraConfig
7
+ from mmgen.models.builder import MODULES
8
+ from mmgen.utils import get_root_logger
9
+ from ..utils import flex_freeze
10
+ from lakonlab.runner.checkpoint import load_checkpoint, _load_checkpoint
11
+
12
+
13
+ @MODULES.register_module()
14
+ class FluxTransformer2DModel(_FluxTransformer2DModel):
15
+
16
+ def __init__(
17
+ self,
18
+ *args,
19
+ patch_size=2,
20
+ freeze=False,
21
+ freeze_exclude=[],
22
+ pretrained=None,
23
+ pretrained_lora=None,
24
+ pretrained_lora_scale=1.0,
25
+ torch_dtype='float32',
26
+ freeze_exclude_fp32=True,
27
+ freeze_exclude_autocast_dtype='float32',
28
+ checkpointing=True,
29
+ use_lora=False,
30
+ lora_target_modules=None,
31
+ lora_rank=16,
32
+ **kwargs):
33
+ with init_empty_weights():
34
+ super().__init__(patch_size=1, *args, **kwargs)
35
+ self.patch_size = patch_size
36
+
37
+ self.init_weights(pretrained, pretrained_lora, pretrained_lora_scale)
38
+
39
+ self.use_lora = use_lora
40
+ self.lora_target_modules = lora_target_modules
41
+ self.lora_rank = lora_rank
42
+ if self.use_lora:
43
+ transformer_lora_config = LoraConfig(
44
+ r=lora_rank,
45
+ lora_alpha=lora_rank,
46
+ init_lora_weights='gaussian',
47
+ target_modules=lora_target_modules,
48
+ )
49
+ self.add_adapter(transformer_lora_config)
50
+
51
+ if torch_dtype is not None:
52
+ self.to(getattr(torch, torch_dtype))
53
+
54
+ self.freeze = freeze
55
+ if self.freeze:
56
+ flex_freeze(
57
+ self,
58
+ exclude_keys=freeze_exclude,
59
+ exclude_fp32=freeze_exclude_fp32,
60
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
61
+
62
+ if checkpointing:
63
+ self.enable_gradient_checkpointing()
64
+
65
+ def init_weights(self, pretrained=None, pretrained_lora=None, pretrained_lora_scale=1.0):
66
+ if pretrained is not None:
67
+ logger = get_root_logger()
68
+ load_checkpoint(
69
+ self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
70
+ if pretrained_lora is not None:
71
+ if not isinstance(pretrained_lora, (list, tuple)):
72
+ assert isinstance(pretrained_lora, str)
73
+ pretrained_lora = [pretrained_lora]
74
+ if not isinstance(pretrained_lora_scale, (list, tuple)):
75
+ assert isinstance(pretrained_lora_scale, (int, float))
76
+ pretrained_lora_scale = [pretrained_lora_scale]
77
+ for pretrained_lora_single, pretrained_lora_scale_single in zip(pretrained_lora, pretrained_lora_scale):
78
+ lora_state_dict = _load_checkpoint(
79
+ pretrained_lora_single, map_location='cpu', logger=logger)
80
+ self.load_lora_adapter(lora_state_dict)
81
+ self.fuse_lora(lora_scale=pretrained_lora_scale_single)
82
+ self.unload_lora()
83
+
84
+ @staticmethod
85
+ def _prepare_latent_image_ids(height, width, device, dtype):
86
+ """
87
+ Copied from Diffusers
88
+ """
89
+ latent_image_ids = torch.zeros(height, width, 3)
90
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
91
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
92
+
93
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
94
+
95
+ latent_image_ids = latent_image_ids.reshape(
96
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels)
97
+
98
+ return latent_image_ids.to(device=device, dtype=dtype)
99
+
100
+ def patchify(self, latents):
101
+ if self.patch_size > 1:
102
+ bs, c, h, w = latents.size()
103
+ latents = latents.reshape(
104
+ bs, c, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size
105
+ ).permute(
106
+ 0, 1, 3, 5, 2, 4
107
+ ).reshape(
108
+ bs, c * self.patch_size * self.patch_size, h // self.patch_size, w // self.patch_size)
109
+ return latents
110
+
111
+ def unpatchify(self, latents):
112
+ if self.patch_size > 1:
113
+ bs, c, h, w = latents.size()
114
+ latents = latents.reshape(
115
+ bs, c // (self.patch_size * self.patch_size), self.patch_size, self.patch_size, h, w
116
+ ).permute(
117
+ 0, 1, 4, 2, 5, 3
118
+ ).reshape(
119
+ bs, c // (self.patch_size * self.patch_size), h * self.patch_size, w * self.patch_size)
120
+ return latents
121
+
122
+ def forward(
123
+ self,
124
+ hidden_states: torch.Tensor,
125
+ timestep: torch.Tensor,
126
+ encoder_hidden_states: torch.Tensor = None,
127
+ pooled_projections: torch.Tensor = None,
128
+ mask: Optional[torch.Tensor] = None,
129
+ masked_image_latents: Optional[torch.Tensor] = None,
130
+ **kwargs):
131
+ hidden_states = self.patchify(hidden_states)
132
+ bs, c, h, w = hidden_states.size()
133
+ dtype = hidden_states.dtype
134
+ device = hidden_states.device
135
+ hidden_states = hidden_states.reshape(bs, c, h * w).permute(0, 2, 1)
136
+ img_ids = self._prepare_latent_image_ids(
137
+ h, w, device, dtype)
138
+ txt_ids = img_ids.new_zeros((encoder_hidden_states.shape[-2], 3))
139
+
140
+ # Flux fill
141
+ if mask is not None and masked_image_latents is not None:
142
+ hidden_states = torch.cat(
143
+ (hidden_states, masked_image_latents.to(dtype=dtype), mask.to(dtype=dtype)), dim=-1)
144
+
145
+ output = super().forward(
146
+ hidden_states=hidden_states,
147
+ encoder_hidden_states=encoder_hidden_states.to(dtype),
148
+ pooled_projections=pooled_projections.to(dtype),
149
+ timestep=timestep,
150
+ img_ids=img_ids,
151
+ txt_ids=txt_ids,
152
+ return_dict=False,
153
+ **kwargs)[0]
154
+
155
+ output = output.permute(0, 2, 1).reshape(bs, self.out_channels, h, w)
156
+ return self.unpatchify(output)
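The `patchify`/`unpatchify` pair above is a plain space-to-depth reshape and its exact inverse. A quick sanity sketch with an illustrative latent size:

```python
# Round-trip check of the patchify/unpatchify logic defined above, with
# patch_size=2 and a hypothetical (bs, c, h, w) latent.
import torch

patch_size = 2
x = torch.randn(1, 16, 64, 64)

bs, c, h, w = x.size()
packed = x.reshape(
    bs, c, h // patch_size, patch_size, w // patch_size, patch_size
).permute(0, 1, 3, 5, 2, 4).reshape(
    bs, c * patch_size * patch_size, h // patch_size, w // patch_size)
print(packed.shape)  # torch.Size([1, 64, 32, 32])

unpacked = packed.reshape(
    bs, c, patch_size, patch_size, h // patch_size, w // patch_size
).permute(0, 1, 4, 2, 5, 3).reshape(bs, c, h, w)
print(torch.equal(unpacked, x))  # True: unpatchify exactly inverts patchify
```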
models/lakonlab/models/architecture/diffusers/pretrained.py ADDED
@@ -0,0 +1,281 @@
1
+ import logging
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from diffusers.models import AutoencoderKL, AutoencoderKLQwenImage
7
+ from diffusers.pipelines import FluxPipeline, QwenImagePipeline, StableDiffusion3Pipeline
8
+ from mmgen.models.builder import MODULES
9
+
10
+ # Suppress truncation warnings from transformers and diffusers
11
+ for name in (
12
+ 'transformers.tokenization_utils_base',
13
+ 'transformers.tokenization_utils',
14
+ 'transformers.tokenization_utils_fast'):
15
+ logging.getLogger(name).setLevel(logging.ERROR)
16
+
17
+ for name, logger in logging.root.manager.loggerDict.items():
18
+ if isinstance(logger, logging.Logger) and (name.startswith('diffusers.pipelines.')):
19
+ logger.setLevel(logging.ERROR)
20
+
21
+
22
+ @MODULES.register_module()
23
+ class PretrainedVAE(nn.Module):
24
+ def __init__(self,
25
+ from_pretrained=None,
26
+ del_encoder=False,
27
+ del_decoder=False,
28
+ use_slicing=False,
29
+ freeze=True,
30
+ eval_mode=True,
31
+ torch_dtype='float32',
32
+ **kwargs):
33
+ super().__init__()
34
+ if torch_dtype is not None:
35
+ kwargs.update(torch_dtype=getattr(torch, torch_dtype))
36
+ self.vae = AutoencoderKL.from_pretrained(
37
+ from_pretrained, **kwargs)
38
+ if del_encoder:
39
+ del self.vae.encoder
40
+ if del_decoder:
41
+ del self.vae.decoder
42
+ if use_slicing:
43
+ self.vae.enable_slicing()
44
+ self.freeze = freeze
45
+ self.eval_mode = eval_mode
46
+ if self.freeze:
47
+ self.requires_grad_(False)
48
+ if self.eval_mode:
49
+ self.eval()
50
+ self.vae.set_use_memory_efficient_attention_xformers(
51
+ not hasattr(torch.nn.functional, 'scaled_dot_product_attention'))
52
+
53
+ def train(self, mode=True):
54
+ mode = mode and (not self.eval_mode)
55
+ return super().train(mode)
56
+
57
+ def forward(self, *args, **kwargs):
58
+ return self.vae(*args, return_dict=False, **kwargs)[0]
59
+
60
+ def encode(self, img):
61
+ scaling_factor = self.vae.config.scaling_factor
62
+ shift_factor = self.vae.config.shift_factor
63
+ if scaling_factor is None:
64
+ scaling_factor = 1.0
65
+ if shift_factor is None:
66
+ shift_factor = 0.0
67
+ return (self.vae.encode(img).latent_dist.sample() - shift_factor) * scaling_factor
68
+
69
+ def decode(self, code):
70
+ scaling_factor = self.vae.config.scaling_factor
71
+ shift_factor = self.vae.config.shift_factor
72
+ if scaling_factor is None:
73
+ scaling_factor = 1.0
74
+ if shift_factor is None:
75
+ shift_factor = 0.0
76
+ return self.vae.decode(code / scaling_factor + shift_factor, return_dict=False)[0]
77
+
78
+
79
+ @MODULES.register_module()
80
+ class PretrainedVAEDecoder(PretrainedVAE):
81
+ def __init__(self, **kwargs):
82
+ super().__init__(
83
+ del_encoder=True,
84
+ del_decoder=False,
85
+ **kwargs)
86
+
87
+ def forward(self, code):
88
+ return super().decode(code)
89
+
90
+
91
+ @MODULES.register_module()
92
+ class PretrainedVAEEncoder(PretrainedVAE):
93
+ def __init__(self, **kwargs):
94
+ super().__init__(
95
+ del_encoder=False,
96
+ del_decoder=True,
97
+ **kwargs)
98
+
99
+ def forward(self, img):
100
+ return super().encode(img)
101
+
102
+
103
+ @MODULES.register_module()
104
+ class PretrainedVAEQwenImage(nn.Module):
105
+ def __init__(self,
106
+ from_pretrained=None,
107
+ use_slicing=False,
108
+ freeze=True,
109
+ eval_mode=True,
110
+ torch_dtype='float32',
111
+ **kwargs):
112
+ super().__init__()
113
+ if torch_dtype is not None:
114
+ kwargs.update(torch_dtype=getattr(torch, torch_dtype))
115
+ self.vae = AutoencoderKLQwenImage.from_pretrained(
116
+ from_pretrained, **kwargs)
117
+ if use_slicing:
118
+ self.vae.enable_slicing()
119
+ self.freeze = freeze
120
+ self.eval_mode = eval_mode
121
+ if self.freeze:
122
+ self.requires_grad_(False)
123
+ if self.eval_mode:
124
+ self.eval()
125
+
126
+ def train(self, mode=True):
127
+ mode = mode and (not self.eval_mode)
128
+ return super().train(mode)
129
+
130
+ def forward(self, *args, **kwargs):
131
+ return self.vae(*args, return_dict=False, **kwargs)[0]
132
+
133
+ def encode(self, img):
134
+ device = img.device
135
+ dtype = img.dtype
136
+ latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=dtype).view(
137
+ 1, self.vae.config.z_dim, 1, 1, 1)
138
+ latents_std = torch.tensor(self.vae.config.latents_std, device=device, dtype=dtype).view(
139
+ 1, self.vae.config.z_dim, 1, 1, 1)
140
+ return ((self.vae.encode(img.unsqueeze(-3)).latent_dist.sample() - latents_mean) / latents_std).squeeze(-3)
141
+
142
+ def decode(self, code):
143
+ device = code.device
144
+ dtype = code.dtype
145
+ latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=dtype).view(
146
+ 1, self.vae.config.z_dim, 1, 1, 1)
147
+ latents_std = torch.tensor(self.vae.config.latents_std, device=device, dtype=dtype).view(
148
+ 1, self.vae.config.z_dim, 1, 1, 1)
149
+ return self.vae.decode(code.unsqueeze(-3) * latents_std + latents_mean, return_dict=False)[0].squeeze(-3)
150
+
151
+
152
+ @MODULES.register_module()
153
+ class PretrainedFluxTextEncoder(nn.Module):
154
+ def __init__(self,
155
+ from_pretrained='black-forest-labs/FLUX.1-dev',
156
+ freeze=True,
157
+ eval_mode=True,
158
+ torch_dtype='bfloat16',
159
+ max_sequence_length=512,
160
+ **kwargs):
161
+ super().__init__()
162
+ self.max_sequence_length = max_sequence_length
163
+ self.pipeline = FluxPipeline.from_pretrained(
164
+ from_pretrained,
165
+ scheduler=None,
166
+ vae=None,
167
+ transformer=None,
168
+ image_encoder=None,
169
+ feature_extractor=None,
170
+ torch_dtype=getattr(torch, torch_dtype),
171
+ **kwargs)
172
+ self.text_encoder = self.pipeline.text_encoder
173
+ self.text_encoder_2 = self.pipeline.text_encoder_2
174
+ self.freeze = freeze
175
+ self.eval_mode = eval_mode
176
+ if self.freeze:
177
+ self.requires_grad_(False)
178
+ if self.eval_mode:
179
+ self.eval()
180
+
181
+ def train(self, mode=True):
182
+ mode = mode and (not self.eval_mode)
183
+ return super().train(mode)
184
+
185
+ def forward(self, prompt, prompt_2=None):
186
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.pipeline.encode_prompt(
187
+ prompt, prompt_2=prompt_2, max_sequence_length=self.max_sequence_length)
188
+ return dict(
189
+ encoder_hidden_states=prompt_embeds,
190
+ pooled_projections=pooled_prompt_embeds)
191
+
192
+
193
+ @MODULES.register_module()
194
+ class PretrainedQwenImageTextEncoder(nn.Module):
195
+ def __init__(self,
196
+ from_pretrained='Qwen/Qwen-Image',
197
+ freeze=True,
198
+ eval_mode=True,
199
+ torch_dtype='bfloat16',
200
+ max_sequence_length=512,
201
+ pad_seq_len=None,
202
+ **kwargs):
203
+ super().__init__()
204
+ self.max_sequence_length = max_sequence_length
205
+ if pad_seq_len is not None:
206
+ assert pad_seq_len >= max_sequence_length
207
+ self.pad_seq_len = pad_seq_len
208
+ self.pipeline = QwenImagePipeline.from_pretrained(
209
+ from_pretrained,
210
+ scheduler=None,
211
+ vae=None,
212
+ transformer=None,
213
+ torch_dtype=getattr(torch, torch_dtype),
214
+ **kwargs)
215
+ self.text_encoder = self.pipeline.text_encoder
216
+ self.freeze = freeze
217
+ self.eval_mode = eval_mode
218
+ if self.freeze:
219
+ self.requires_grad_(False)
220
+ if self.eval_mode:
221
+ self.eval()
222
+
223
+ def train(self, mode=True):
224
+ mode = mode and (not self.eval_mode)
225
+ return super().train(mode)
226
+
227
+ def forward(self, prompt):
228
+ prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
229
+ prompt, max_sequence_length=self.max_sequence_length)
230
+ if self.pad_seq_len is not None:
231
+ pad_len = self.pad_seq_len - prompt_embeds.size(1)
232
+ prompt_embeds = F.pad(
233
+ prompt_embeds, (0, 0, 0, pad_len), value=0.0)
234
+ prompt_embeds_mask = F.pad(
235
+ prompt_embeds_mask, (0, pad_len), value=0.0)
236
+ return dict(
237
+ encoder_hidden_states=prompt_embeds,
238
+ encoder_hidden_states_mask=prompt_embeds_mask)
239
+
240
+
241
+ @MODULES.register_module()
242
+ class PretrainedStableDiffusion3TextEncoder(nn.Module):
243
+ def __init__(self,
244
+ from_pretrained='stabilityai/stable-diffusion-3.5-large',
245
+ freeze=True,
246
+ eval_mode=True,
247
+ torch_dtype='float32',
248
+ max_sequence_length=256,
249
+ **kwargs):
250
+ super().__init__()
251
+ self.max_sequence_length = max_sequence_length
252
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
253
+ from_pretrained,
254
+ scheduler=None,
255
+ vae=None,
256
+ transformer=None,
257
+ image_encoder=None,
258
+ feature_extractor=None,
259
+ torch_dtype=getattr(torch, torch_dtype),
260
+ **kwargs)
261
+ self.text_encoder = self.pipeline.text_encoder
262
+ self.text_encoder_2 = self.pipeline.text_encoder_2
263
+ self.text_encoder_3 = self.pipeline.text_encoder_3
264
+ self.freeze = freeze
265
+ self.eval_mode = eval_mode
266
+ if self.freeze:
267
+ self.requires_grad_(False)
268
+ if self.eval_mode:
269
+ self.eval()
270
+
271
+ def train(self, mode=True):
272
+ mode = mode and (not self.eval_mode)
273
+ return super().train(mode)
274
+
275
+ def forward(self, prompt, prompt_2=None, prompt_3=None):
276
+ prompt_embeds, _, pooled_prompt_embeds, _ = self.pipeline.encode_prompt(
277
+ prompt, prompt_2=prompt_2, prompt_3=prompt_3, do_classifier_free_guidance=False,
278
+ max_sequence_length=self.max_sequence_length)
279
+ return dict(
280
+ encoder_hidden_states=prompt_embeds,
281
+ pooled_projections=pooled_prompt_embeds)
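The VAE wrappers above all follow the same latent normalization: `encode` applies `(z - shift) * scale` and `decode` inverts it before calling the decoder. A tiny sketch with placeholder factors (not values from a real checkpoint config):

```python
# Illustrative sketch of the latent normalization used by PretrainedVAE.
import torch

scaling_factor = 0.5   # hypothetical vae.config.scaling_factor
shift_factor = 0.1     # hypothetical vae.config.shift_factor

z_vae = torch.randn(1, 4, 32, 32)                 # raw posterior sample
code = (z_vae - shift_factor) * scaling_factor    # what encode() returns
z_back = code / scaling_factor + shift_factor     # what decode() feeds the VAE
print(torch.allclose(z_back, z_vae))              # True
```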
models/lakonlab/models/architecture/diffusers/qwen.py ADDED
@@ -0,0 +1,139 @@
1
+ import torch
2
+
3
+ from accelerate import init_empty_weights
4
+ from diffusers.models import QwenImageTransformer2DModel as _QwenImageTransformer2DModel
5
+ from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_qwen_lora_to_diffusers
6
+ from peft import LoraConfig
7
+ from mmgen.models.builder import MODULES
8
+ from mmgen.utils import get_root_logger
9
+ from ..utils import flex_freeze
10
+ from lakonlab.runner.checkpoint import load_checkpoint, _load_checkpoint
11
+
12
+
13
+ @MODULES.register_module()
14
+ class QwenImageTransformer2DModel(_QwenImageTransformer2DModel):
15
+
16
+ def __init__(
17
+ self,
18
+ *args,
19
+ patch_size=2,
20
+ freeze=False,
21
+ freeze_exclude=[],
22
+ pretrained=None,
23
+ pretrained_lora=None,
24
+ pretrained_lora_scale=1.0,
25
+ torch_dtype='float32',
26
+ freeze_exclude_fp32=True,
27
+ freeze_exclude_autocast_dtype='float32',
28
+ checkpointing=True,
29
+ use_lora=False,
30
+ lora_target_modules=None,
31
+ lora_rank=16,
32
+ **kwargs):
33
+ with init_empty_weights():
34
+ super().__init__(*args, patch_size=1, **kwargs)
35
+ self.patch_size = patch_size
36
+
37
+ self.init_weights(pretrained, pretrained_lora, pretrained_lora_scale)
38
+
39
+ self.use_lora = use_lora
40
+ self.lora_target_modules = lora_target_modules
41
+ self.lora_rank = lora_rank
42
+ if self.use_lora:
43
+ transformer_lora_config = LoraConfig(
44
+ r=lora_rank,
45
+ lora_alpha=lora_rank,
46
+ init_lora_weights='gaussian',
47
+ target_modules=lora_target_modules,
48
+ )
49
+ self.add_adapter(transformer_lora_config)
50
+
51
+ if torch_dtype is not None:
52
+ self.to(getattr(torch, torch_dtype))
53
+
54
+ self.freeze = freeze
55
+ if self.freeze:
56
+ flex_freeze(
57
+ self,
58
+ exclude_keys=freeze_exclude,
59
+ exclude_fp32=freeze_exclude_fp32,
60
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
61
+
62
+ if checkpointing:
63
+ self.enable_gradient_checkpointing()
64
+
65
+ def init_weights(self, pretrained=None, pretrained_lora=None, pretrained_lora_scale=1.0):
66
+ if pretrained is not None:
67
+ logger = get_root_logger()
68
+ load_checkpoint(
69
+ self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
70
+ if pretrained_lora is not None:
71
+ if not isinstance(pretrained_lora, (list, tuple)):
72
+ assert isinstance(pretrained_lora, str)
73
+ pretrained_lora = [pretrained_lora]
74
+ if not isinstance(pretrained_lora_scale, (list, tuple)):
75
+ assert isinstance(pretrained_lora_scale, (int, float))
76
+ pretrained_lora_scale = [pretrained_lora_scale]
77
+ for pretrained_lora_single, pretrained_lora_scale_single in zip(pretrained_lora, pretrained_lora_scale):
78
+ lora_state_dict = _load_checkpoint(
79
+ pretrained_lora_single, map_location='cpu', logger=logger)
80
+ lora_state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(lora_state_dict)
81
+ self.load_lora_adapter(lora_state_dict)
82
+ self.fuse_lora(lora_scale=pretrained_lora_scale_single)
83
+ self.unload_lora()
84
+
85
+ def patchify(self, latents):
86
+ if self.patch_size > 1:
87
+ bs, c, h, w = latents.size()
88
+ latents = latents.reshape(
89
+ bs, c, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size
90
+ ).permute(
91
+ 0, 1, 3, 5, 2, 4
92
+ ).reshape(
93
+ bs, c * self.patch_size * self.patch_size, h // self.patch_size, w // self.patch_size)
94
+ return latents
95
+
96
+ def unpatchify(self, latents):
97
+ if self.patch_size > 1:
98
+ bs, c, h, w = latents.size()
99
+ latents = latents.reshape(
100
+ bs, c // (self.patch_size * self.patch_size), self.patch_size, self.patch_size, h, w
101
+ ).permute(
102
+ 0, 1, 4, 2, 5, 3
103
+ ).reshape(
104
+ bs, c // (self.patch_size * self.patch_size), h * self.patch_size, w * self.patch_size)
105
+ return latents
106
+
107
+ def forward(
108
+ self,
109
+ hidden_states: torch.Tensor,
110
+ timestep: torch.Tensor,
111
+ encoder_hidden_states: torch.Tensor = None,
112
+ encoder_hidden_states_mask: torch.Tensor = None,
113
+ **kwargs):
114
+ hidden_states = self.patchify(hidden_states)
115
+ bs, c, h, w = hidden_states.size()
116
+ dtype = hidden_states.dtype
117
+ hidden_states = hidden_states.reshape(bs, c, h * w).permute(0, 2, 1)
118
+ img_shapes = [[(1, h, w)]]
119
+ if encoder_hidden_states_mask is not None:
120
+ txt_seq_lens = encoder_hidden_states_mask.sum(dim=1)
121
+ max_txt_seq_len = txt_seq_lens.max()
122
+ encoder_hidden_states = encoder_hidden_states[:, :max_txt_seq_len]
123
+ encoder_hidden_states_mask = encoder_hidden_states_mask[:, :max_txt_seq_len]
124
+ txt_seq_lens = txt_seq_lens.tolist()
125
+ else:
126
+ txt_seq_lens = None
127
+
128
+ output = super().forward(
129
+ hidden_states=hidden_states,
130
+ encoder_hidden_states=encoder_hidden_states.to(dtype),
131
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
132
+ timestep=timestep,
133
+ img_shapes=img_shapes,
134
+ txt_seq_lens=txt_seq_lens,
135
+ return_dict=False,
136
+ **kwargs)[0]
137
+
138
+ output = output.permute(0, 2, 1).reshape(bs, self.out_channels, h, w)
139
+ return self.unpatchify(output)
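In `QwenImageTransformer2DModel.forward` above, padded prompt embeddings are cropped to the longest real sequence in the batch before being handed to the backbone, with the per-sample lengths passed as `txt_seq_lens`. A small sketch with illustrative shapes:

```python
# Sketch of the prompt-trimming step: crop padded text embeddings to the
# longest real sequence, keep per-sample lengths. Shapes are illustrative.
import torch

bs, pad_len, dim = 2, 10, 8
encoder_hidden_states = torch.randn(bs, pad_len, dim)
encoder_hidden_states_mask = torch.tensor(
    [[1] * 4 + [0] * 6,
     [1] * 7 + [0] * 3])

txt_seq_lens = encoder_hidden_states_mask.sum(dim=1)
max_txt_seq_len = txt_seq_lens.max()
encoder_hidden_states = encoder_hidden_states[:, :max_txt_seq_len]
encoder_hidden_states_mask = encoder_hidden_states_mask[:, :max_txt_seq_len]

print(encoder_hidden_states.shape)   # torch.Size([2, 7, 8])
print(txt_seq_lens.tolist())         # [4, 7]
```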
models/lakonlab/models/architecture/diffusers/sd3.py ADDED
@@ -0,0 +1,80 @@
+import torch
+
+from accelerate import init_empty_weights
+from diffusers.models import SD3Transformer2DModel as _SD3Transformer2DModel
+from peft import LoraConfig
+from mmgen.models.builder import MODULES
+from mmgen.utils import get_root_logger
+from ..utils import flex_freeze
+from lakonlab.runner.checkpoint import load_checkpoint
+
+
+@MODULES.register_module()
+class SD3Transformer2DModel(_SD3Transformer2DModel):
+
+    def __init__(
+            self,
+            *args,
+            freeze=False,
+            freeze_exclude=[],
+            pretrained=None,
+            torch_dtype='float32',
+            freeze_exclude_fp32=True,
+            freeze_exclude_autocast_dtype='float32',
+            checkpointing=True,
+            use_lora=False,
+            lora_target_modules=None,
+            lora_rank=16,
+            **kwargs):
+        with init_empty_weights():
+            super().__init__(*args, **kwargs)
+        self.init_weights(pretrained)
+
+        self.use_lora = use_lora
+        self.lora_target_modules = lora_target_modules
+        self.lora_rank = lora_rank
+        if self.use_lora:
+            transformer_lora_config = LoraConfig(
+                r=lora_rank,
+                lora_alpha=lora_rank,
+                init_lora_weights='gaussian',
+                target_modules=lora_target_modules,
+            )
+            self.add_adapter(transformer_lora_config)
+
+        if torch_dtype is not None:
+            self.to(getattr(torch, torch_dtype))
+
+        self.freeze = freeze
+        if self.freeze:
+            flex_freeze(
+                self,
+                exclude_keys=freeze_exclude,
+                exclude_fp32=freeze_exclude_fp32,
+                exclude_autocast_dtype=freeze_exclude_autocast_dtype)
+
+        if checkpointing:
+            self.enable_gradient_checkpointing()
+
+    def init_weights(self, pretrained=None):
+        if pretrained is not None:
+            logger = get_root_logger()
+            load_checkpoint(
+                self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            timestep: torch.Tensor,
+            encoder_hidden_states: torch.Tensor = None,
+            pooled_projections: torch.Tensor = None,
+            **kwargs):
+        dtype = hidden_states.dtype
+
+        return super().forward(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states.to(dtype),
+            pooled_projections=pooled_projections.to(dtype),
+            timestep=timestep,
+            return_dict=False,
+            **kwargs)[0]
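The LoRA setup above is shared by the Flux and Qwen wrappers as well: a peft `LoraConfig` with `lora_alpha == r` (so the effective LoRA scale is 1.0) attached through diffusers' `add_adapter`. A minimal sketch; the target module names are placeholders, the real list comes from `lora_target_modules` in the config:

```python
# Sketch of the LoRA wiring used when use_lora=True. Target modules below are
# hypothetical examples, not the commit's actual configuration.
from peft import LoraConfig

lora_rank = 16
transformer_lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_rank,              # alpha == rank -> LoRA scale of 1.0
    init_lora_weights='gaussian',
    target_modules=['to_q', 'to_k', 'to_v', 'to_out.0'],  # placeholder targets
)
# transformer.add_adapter(transformer_lora_config)  # as done in __init__
```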
models/lakonlab/models/architecture/diffusers/unet.py ADDED
@@ -0,0 +1,192 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from typing import Dict, Any, Optional, Union, Tuple
5
+ from collections import OrderedDict
6
+ from diffusers.models import UNet2DConditionModel as _UNet2DConditionModel
7
+ from mmcv.runner import _load_checkpoint, load_state_dict
8
+ from mmgen.models.builder import MODULES
9
+ from mmgen.utils import get_root_logger
10
+ from ..utils import flex_freeze
11
+
12
+
13
+ def ceildiv(a, b):
14
+ return -(a // -b)
15
+
16
+
17
+ def unet_enc(
18
+ unet,
19
+ sample: torch.FloatTensor,
20
+ timestep: Union[torch.Tensor, float, int],
21
+ encoder_hidden_states: torch.Tensor,
22
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
23
+ added_cond_kwargs=None):
24
+ # 0. center input if necessary
25
+ if unet.config.center_input_sample:
26
+ sample = 2 * sample - 1.0
27
+
28
+ # 1. time
29
+ t_emb = unet.get_time_embed(sample=sample, timestep=timestep)
30
+ emb = unet.time_embedding(t_emb)
31
+ aug_emb = unet.get_aug_embed(
32
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
33
+ emb = emb + aug_emb if aug_emb is not None else emb
34
+
35
+ if unet.time_embed_act is not None:
36
+ emb = unet.time_embed_act(emb)
37
+
38
+ encoder_hidden_states = unet.process_encoder_hidden_states(
39
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
40
+
41
+ # 2. pre-process
42
+ sample = unet.conv_in(sample)
43
+
44
+ # 3. down
45
+ down_block_res_samples = (sample,)
46
+ for downsample_block in unet.down_blocks:
47
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
48
+ sample, res_samples = downsample_block(
49
+ hidden_states=sample,
50
+ temb=emb,
51
+ encoder_hidden_states=encoder_hidden_states,
52
+ cross_attention_kwargs=cross_attention_kwargs,
53
+ )
54
+ else:
55
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
56
+
57
+ down_block_res_samples += res_samples
58
+
59
+ return emb, down_block_res_samples, sample
60
+
61
+
62
+ def unet_dec(
63
+ unet,
64
+ emb,
65
+ down_block_res_samples,
66
+ sample,
67
+ encoder_hidden_states: torch.Tensor,
68
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
69
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
70
+ mid_block_additional_residual: Optional[torch.Tensor] = None):
71
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
72
+
73
+ if is_controlnet:
74
+ new_down_block_res_samples = ()
75
+
76
+ for down_block_res_sample, down_block_additional_residual in zip(
77
+ down_block_res_samples, down_block_additional_residuals):
78
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
79
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
80
+
81
+ down_block_res_samples = new_down_block_res_samples
82
+
83
+ # 4. mid
84
+ if unet.mid_block is not None:
85
+ if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
86
+ sample = unet.mid_block(
87
+ sample,
88
+ emb,
89
+ encoder_hidden_states=encoder_hidden_states,
90
+ cross_attention_kwargs=cross_attention_kwargs,
91
+ )
92
+ else:
93
+ sample = unet.mid_block(sample, emb)
94
+
95
+ if is_controlnet:
96
+ sample = sample + mid_block_additional_residual
97
+
98
+ # 5. up
99
+ for i, upsample_block in enumerate(unet.up_blocks):
100
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
101
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
102
+
103
+ if hasattr(upsample_block, 'has_cross_attention') and upsample_block.has_cross_attention:
104
+ sample = upsample_block(
105
+ hidden_states=sample,
106
+ temb=emb,
107
+ res_hidden_states_tuple=res_samples,
108
+ encoder_hidden_states=encoder_hidden_states,
109
+ cross_attention_kwargs=cross_attention_kwargs,
110
+ )
111
+ else:
112
+ sample = upsample_block(
113
+ hidden_states=sample,
114
+ temb=emb,
115
+ res_hidden_states_tuple=res_samples,
116
+ )
117
+
118
+ # 6. post-process
119
+ if unet.conv_norm_out:
120
+ sample = unet.conv_norm_out(sample)
121
+ sample = unet.conv_act(sample)
122
+ sample = unet.conv_out(sample)
123
+
124
+ return sample
125
+
126
+
127
+ @MODULES.register_module()
128
+ class UNet2DConditionModel(_UNet2DConditionModel):
129
+ def __init__(self,
130
+ *args,
131
+ freeze=True,
132
+ freeze_exclude=[],
133
+ pretrained=None,
134
+ torch_dtype='float32',
135
+ freeze_exclude_fp32=True,
136
+ freeze_exclude_autocast_dtype='float32',
137
+ **kwargs):
138
+ super().__init__(*args, **kwargs)
139
+
140
+ self.init_weights(pretrained)
141
+ if torch_dtype is not None:
142
+ self.to(getattr(torch, torch_dtype))
143
+
144
+ self.set_use_memory_efficient_attention_xformers(
145
+ not hasattr(torch.nn.functional, 'scaled_dot_product_attention'))
146
+
147
+ self.freeze = freeze
148
+ if self.freeze:
149
+ flex_freeze(
150
+ self,
151
+ exclude_keys=freeze_exclude,
152
+ exclude_fp32=freeze_exclude_fp32,
153
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
154
+
155
+ def init_weights(self, pretrained):
156
+ if pretrained is not None:
157
+ logger = get_root_logger()
158
+ # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
159
+ checkpoint = _load_checkpoint(pretrained, map_location='cpu', logger=logger)
160
+ if 'state_dict' in checkpoint:
161
+ state_dict = checkpoint['state_dict']
162
+ else:
163
+ state_dict = checkpoint
164
+ metadata = getattr(state_dict, '_metadata', OrderedDict())
165
+ state_dict._metadata = metadata
166
+ assert self.conv_in.weight.shape[1] == self.conv_out.weight.shape[0]
167
+ if state_dict['conv_in.weight'].size() != self.conv_in.weight.size():
168
+ assert state_dict['conv_in.weight'].shape[1] == state_dict['conv_out.weight'].shape[0]
169
+ src_chn = state_dict['conv_in.weight'].shape[1]
170
+ tgt_chn = self.conv_in.weight.shape[1]
171
+ assert src_chn < tgt_chn
172
+ convert_mat_out = torch.tile(torch.eye(src_chn), (ceildiv(tgt_chn, src_chn), 1))
173
+ convert_mat_out = convert_mat_out[:tgt_chn]
174
+ convert_mat_in = F.normalize(convert_mat_out.pinverse(), dim=-1)
175
+ state_dict['conv_out.weight'] = torch.einsum(
176
+ 'ts,scxy->tcxy', convert_mat_out, state_dict['conv_out.weight'])
177
+ state_dict['conv_out.bias'] = torch.einsum(
178
+ 'ts,s->t', convert_mat_out, state_dict['conv_out.bias'])
179
+ state_dict['conv_in.weight'] = torch.einsum(
180
+ 'st,csxy->ctxy', convert_mat_in, state_dict['conv_in.weight'])
181
+ load_state_dict(self, state_dict, logger=logger)
182
+
183
+ def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
184
+ dtype = sample.dtype
185
+ return super().forward(
186
+ sample, timestep, encoder_hidden_states, return_dict=False, **kwargs)[0].to(dtype)
187
+
188
+ def forward_enc(self, sample, timestep, encoder_hidden_states, **kwargs):
189
+ return unet_enc(self, sample, timestep, encoder_hidden_states, **kwargs)
190
+
191
+ def forward_dec(self, emb, down_block_res_samples, sample, encoder_hidden_states, **kwargs):
192
+ return unet_dec(self, emb, down_block_res_samples, sample, encoder_hidden_states, **kwargs)
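`forward_enc`/`forward_dec` above split the UNet into an encoder half (returning the time embedding, the down-block residuals, and the mid-block input) and a decoder half that consumes them, which also exposes the ControlNet hook in `unet_dec`. The toy sketch below shows just that residual-merging step with stand-in tensors:

```python
# Toy illustration of the ControlNet path in unet_dec: each down-block feature
# gets the matching additional residual added before the decoder consumes it.
# The tensors are stand-ins for real UNet features.
import torch

down_block_res_samples = tuple(torch.randn(1, 4, 8, 8) for _ in range(3))
down_block_additional_residuals = tuple(torch.randn(1, 4, 8, 8) for _ in range(3))

new_down_block_res_samples = ()
for res_sample, extra in zip(down_block_res_samples, down_block_additional_residuals):
    new_down_block_res_samples += (res_sample + extra,)

print(len(new_down_block_res_samples), new_down_block_res_samples[0].shape)
# 3 torch.Size([1, 4, 8, 8])
```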
piFlow/lakonlab/models/architecture/diffusers/__init__.py ADDED
@@ -0,0 +1,15 @@
+from .pretrained import (
+    PretrainedVAE, PretrainedVAEDecoder, PretrainedVAEEncoder, PretrainedVAEQwenImage,
+    PretrainedFluxTextEncoder, PretrainedQwenImageTextEncoder, PretrainedStableDiffusion3TextEncoder)
+from .unet import UNet2DConditionModel
+from .flux import FluxTransformer2DModel
+from .dit import DiTTransformer2DModelMod
+from .sd3 import SD3Transformer2DModel
+from .qwen import QwenImageTransformer2DModel
+
+__all__ = [
+    'PretrainedVAE', 'PretrainedVAEDecoder', 'PretrainedVAEEncoder', 'PretrainedFluxTextEncoder',
+    'PretrainedQwenImageTextEncoder', 'UNet2DConditionModel', 'FluxTransformer2DModel',
+    'DiTTransformer2DModelMod', 'SD3Transformer2DModel',
+    'QwenImageTransformer2DModel', 'PretrainedVAEQwenImage', 'PretrainedStableDiffusion3TextEncoder',
+]
piFlow/lakonlab/models/architecture/diffusers/dit.py ADDED
@@ -0,0 +1,428 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Any, Dict, Optional
6
+ from diffusers.models import DiTTransformer2DModel, ModelMixin
7
+ from diffusers.models.attention import BasicTransformerBlock, _chunked_feed_forward, Attention, FeedForward
8
+ from diffusers.models.embeddings import (
9
+ PatchEmbed, Timesteps, CombinedTimestepLabelEmbeddings, TimestepEmbedding, LabelEmbedding)
10
+ from diffusers.models.normalization import AdaLayerNormZero
11
+ from diffusers.configuration_utils import register_to_config
12
+ from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict
13
+ from mmcv.cnn import constant_init, xavier_init
14
+ from mmgen.models.builder import MODULES
15
+ from mmgen.utils import get_root_logger
16
+ from ..utils import flex_freeze
17
+
18
+
19
+ class LabelEmbeddingMod(LabelEmbedding):
20
+ def __init__(self, num_classes, hidden_size, dropout_prob=0.0, use_cfg_embedding=True):
21
+ super(LabelEmbedding, self).__init__()
22
+ if dropout_prob > 0:
23
+ assert use_cfg_embedding
24
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
25
+ self.num_classes = num_classes
26
+ self.dropout_prob = dropout_prob
27
+
28
+
29
+ class CombinedTimestepLabelEmbeddingsMod(CombinedTimestepLabelEmbeddings):
30
+ """
31
+ Modified CombinedTimestepLabelEmbeddings for reproducing the original DiT (downscale_freq_shift=0).
32
+ """
33
+ def __init__(
34
+ self, num_classes, embedding_dim, class_dropout_prob=0.1, downscale_freq_shift=0, use_cfg_embedding=True):
35
+ super(CombinedTimestepLabelEmbeddings, self).__init__()
36
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=downscale_freq_shift)
37
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
38
+ self.class_embedder = LabelEmbeddingMod(num_classes, embedding_dim, class_dropout_prob, use_cfg_embedding)
39
+
40
+
41
+ class BasicTransformerBlockMod(BasicTransformerBlock):
42
+ """
43
+ Modified BasicTransformerBlock for reproducing the original DiT with shared time and class
44
+ embeddings across all layers.
45
+ """
46
+ def __init__(
47
+ self,
48
+ dim: int,
49
+ num_attention_heads: int,
50
+ attention_head_dim: int,
51
+ dropout=0.0,
52
+ cross_attention_dim: Optional[int] = None,
53
+ activation_fn: str = 'geglu',
54
+ num_embeds_ada_norm: Optional[int] = None,
55
+ attention_bias: bool = False,
56
+ only_cross_attention: bool = False,
57
+ double_self_attention: bool = False,
58
+ upcast_attention: bool = False,
59
+ norm_elementwise_affine: bool = True,
60
+ norm_type: str = 'layer_norm',
61
+ norm_eps: float = 1e-5,
62
+ final_dropout: bool = False,
63
+ attention_type: str = 'default',
64
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
65
+ ada_norm_bias: Optional[int] = None,
66
+ ff_inner_dim: Optional[int] = None,
67
+ ff_bias: bool = True,
68
+ attention_out_bias: bool = True):
69
+ super(BasicTransformerBlock, self).__init__()
70
+ self.only_cross_attention = only_cross_attention
71
+ self.norm_type = norm_type
72
+ self.num_embeds_ada_norm = num_embeds_ada_norm
73
+
74
+ assert self.norm_type == 'ada_norm_zero'
75
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
76
+ self.attn1 = Attention(
77
+ query_dim=dim,
78
+ heads=num_attention_heads,
79
+ dim_head=attention_head_dim,
80
+ dropout=dropout,
81
+ bias=attention_bias,
82
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
83
+ upcast_attention=upcast_attention,
84
+ out_bias=attention_out_bias,
85
+ )
86
+
87
+ self.norm2 = None
88
+ self.attn2 = None
89
+
90
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
91
+ self.ff = FeedForward(
92
+ dim,
93
+ dropout=dropout,
94
+ activation_fn=activation_fn,
95
+ final_dropout=final_dropout,
96
+ inner_dim=ff_inner_dim,
97
+ bias=ff_bias,
98
+ )
99
+
100
+ self._chunk_size = None
101
+ self._chunk_dim = 0
102
+
103
+ def forward(
104
+ self,
105
+ hidden_states: torch.Tensor,
106
+ attention_mask: Optional[torch.Tensor] = None,
107
+ encoder_hidden_states: Optional[torch.Tensor] = None,
108
+ encoder_attention_mask: Optional[torch.Tensor] = None,
109
+ timestep: Optional[torch.LongTensor] = None,
110
+ cross_attention_kwargs: Dict[str, Any] = None,
111
+ class_labels: Optional[torch.LongTensor] = None,
112
+ emb: Optional[torch.Tensor] = None,
113
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
114
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
115
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype, emb=emb)
116
+
117
+ if cross_attention_kwargs is None:
118
+ cross_attention_kwargs = dict()
119
+ attn_output = self.attn1(
120
+ norm_hidden_states,
121
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
122
+ attention_mask=attention_mask,
123
+ **cross_attention_kwargs)
124
+ attn_output = gate_msa.unsqueeze(1) * attn_output
125
+
126
+ hidden_states = attn_output + hidden_states
127
+ if hidden_states.ndim == 4:
128
+ hidden_states = hidden_states.squeeze(1)
129
+
130
+ norm_hidden_states = self.norm3(hidden_states)
131
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
132
+
133
+ if self._chunk_size is not None:
134
+ # "feed_forward_chunk_size" can be used to save memory
135
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
136
+ else:
137
+ ff_output = self.ff(norm_hidden_states)
138
+
139
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
140
+
141
+ hidden_states = ff_output + hidden_states
142
+ if hidden_states.ndim == 4:
143
+ hidden_states = hidden_states.squeeze(1)
144
+
145
+ return hidden_states
146
+
147
+
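Note: the block above applies adaLN-Zero conditioning: `norm1` turns the shared embedding into per-sample shift/scale/gate vectors, the attention and feed-forward branches are gated before their residual adds, and the MLP input is re-modulated with `scale_mlp`/`shift_mlp`. A minimal sketch of that modulation arithmetic on dummy tensors (shapes and values are illustrative only):

import torch

bs, seq, dim = 2, 16, 8
hidden = torch.randn(bs, seq, dim)
shift, scale, gate = torch.randn(bs, dim), torch.randn(bs, dim), torch.randn(bs, dim)

# modulate: x * (1 + scale) + shift, broadcast over the sequence axis
modulated = hidden * (1 + scale[:, None]) + shift[:, None]

# gate the branch output before the residual connection
branch_out = torch.randn(bs, seq, dim)      # stands in for the attn/ff output
hidden = hidden + gate.unsqueeze(1) * branch_out
print(modulated.shape, hidden.shape)        # both torch.Size([2, 16, 8])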
148
+ class _DiTTransformer2DModelMod(DiTTransformer2DModel):
149
+
150
+ @register_to_config
151
+ def __init__(
152
+ self,
153
+ class_dropout_prob=0.0,
154
+ num_attention_heads: int = 16,
155
+ attention_head_dim: int = 72,
156
+ in_channels: int = 4,
157
+ out_channels: Optional[int] = None,
158
+ num_layers: int = 28,
159
+ dropout: float = 0.0,
160
+ norm_num_groups: int = 32,
161
+ attention_bias: bool = True,
162
+ sample_size: int = 32,
163
+ patch_size: int = 2,
164
+ activation_fn: str = 'gelu-approximate',
165
+ num_embeds_ada_norm: Optional[int] = 1000,
166
+ upcast_attention: bool = False,
167
+ norm_type: str = 'ada_norm_zero',
168
+ norm_elementwise_affine: bool = False,
169
+ norm_eps: float = 1e-5):
170
+
171
+ super(DiTTransformer2DModel, self).__init__()
172
+
173
+ # Validate inputs.
174
+ if norm_type != "ada_norm_zero":
175
+ raise NotImplementedError(
176
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
177
+ )
178
+ elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
179
+ raise ValueError(
180
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
181
+ )
182
+
183
+ # Set some common variables used across the board.
184
+ self.attention_head_dim = attention_head_dim
185
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
186
+ self.out_channels = in_channels if out_channels is None else out_channels
187
+ self.gradient_checkpointing = False
188
+
189
+ # 2. Initialize the position embedding and transformer blocks.
190
+ self.height = self.config.sample_size
191
+ self.width = self.config.sample_size
192
+
193
+ self.patch_size = self.config.patch_size
194
+ self.pos_embed = PatchEmbed(
195
+ height=self.config.sample_size,
196
+ width=self.config.sample_size,
197
+ patch_size=self.config.patch_size,
198
+ in_channels=self.config.in_channels,
199
+ embed_dim=self.inner_dim)
200
+ self.emb = CombinedTimestepLabelEmbeddingsMod(
201
+ num_embeds_ada_norm, self.inner_dim, class_dropout_prob=0.0)
202
+
203
+ self.transformer_blocks = nn.ModuleList([
204
+ BasicTransformerBlockMod(
205
+ self.inner_dim,
206
+ self.config.num_attention_heads,
207
+ self.config.attention_head_dim,
208
+ dropout=self.config.dropout,
209
+ activation_fn=self.config.activation_fn,
210
+ num_embeds_ada_norm=None,
211
+ attention_bias=self.config.attention_bias,
212
+ upcast_attention=self.config.upcast_attention,
213
+ norm_type=norm_type,
214
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
215
+ norm_eps=self.config.norm_eps)
216
+ for _ in range(self.config.num_layers)])
217
+
218
+ # 3. Output blocks.
219
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
220
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
221
+ self.proj_out_2 = nn.Linear(
222
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
223
+
224
+ # https://github.com/facebookresearch/DiT/blob/main/models.py
225
+ def init_weights(self):
226
+ for m in self.modules():
227
+ if isinstance(m, nn.Linear):
228
+ xavier_init(m, distribution='uniform')
229
+ elif isinstance(m, nn.Embedding):
230
+ torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
231
+
232
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
233
+ w = self.pos_embed.proj.weight.data
234
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
235
+ nn.init.constant_(self.pos_embed.proj.bias, 0)
236
+
237
+ # Zero-out adaLN modulation layers in DiT blocks
238
+ for m in self.modules():
239
+ if isinstance(m, AdaLayerNormZero):
240
+ constant_init(m.linear, val=0)
241
+
242
+ # Zero-out output layers
243
+ constant_init(self.proj_out_1, val=0)
244
+ constant_init(self.proj_out_2, val=0)
245
+
246
+ def forward(
247
+ self,
248
+ hidden_states: torch.Tensor,
249
+ timestep: Optional[torch.LongTensor] = None,
250
+ class_labels: Optional[torch.LongTensor] = None,
251
+ cross_attention_kwargs: Dict[str, Any] = None):
252
+ # 1. Input
253
+ bs, _, h, w = hidden_states.size()
254
+ height, width = h // self.patch_size, w // self.patch_size
255
+ hidden_states = self.pos_embed(hidden_states)
256
+
257
+ cond_emb = self.emb(
258
+ timestep, class_labels, hidden_dtype=hidden_states.dtype)
259
+ dropout_enabled = self.config.class_dropout_prob > 0 and self.training
260
+ if dropout_enabled:
261
+ uncond_emb = self.emb(timestep, torch.full_like(
262
+ class_labels, self.config.num_embeds_ada_norm), hidden_dtype=hidden_states.dtype)
263
+
264
+ # 2. Blocks
265
+ for block in self.transformer_blocks:
266
+ if dropout_enabled:
267
+ dropout_mask = torch.rand((bs, 1), device=hidden_states.device) < self.config.class_dropout_prob
268
+ emb = torch.where(dropout_mask, uncond_emb, cond_emb)
269
+ else:
270
+ emb = cond_emb
271
+
272
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
273
+
274
+ def create_custom_forward(module, return_dict=None):
275
+ def custom_forward(*inputs):
276
+ if return_dict is not None:
277
+ return module(*inputs, return_dict=return_dict)
278
+ else:
279
+ return module(*inputs)
280
+
281
+ return custom_forward
282
+
283
+ hidden_states = torch.utils.checkpoint.checkpoint(
284
+ create_custom_forward(block),
285
+ hidden_states,
286
+ None,
287
+ None,
288
+ None,
289
+ timestep,
290
+ cross_attention_kwargs,
291
+ class_labels,
292
+ emb,
293
+ use_reentrant=False)
294
+
295
+ else:
296
+ hidden_states = block(
297
+ hidden_states,
298
+ attention_mask=None,
299
+ encoder_hidden_states=None,
300
+ encoder_attention_mask=None,
301
+ timestep=timestep,
302
+ cross_attention_kwargs=cross_attention_kwargs,
303
+ class_labels=class_labels,
304
+ emb=emb)
305
+
306
+ # 3. Output
307
+ if dropout_enabled:
308
+ dropout_mask = torch.rand((bs, 1), device=hidden_states.device) < self.config.class_dropout_prob
309
+ emb = torch.where(dropout_mask, uncond_emb, cond_emb)
310
+ else:
311
+ emb = cond_emb
312
+ shift, scale = self.proj_out_1(F.silu(emb)).chunk(2, dim=1)
313
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
314
+ output = self.proj_out_2(hidden_states).reshape(
315
+ bs, height, width, self.patch_size, self.patch_size, self.out_channels
316
+ ).permute(0, 5, 1, 3, 2, 4).reshape(
317
+ bs, self.out_channels, height * self.patch_size, width * self.patch_size)
318
+
319
+ return output
320
+
321
+
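Note: the final reshape in the forward pass above converts per-patch token predictions back into an image layout. The same permute/reshape, isolated with hypothetical sizes so the shapes can be followed step by step:

import torch

bs, height, width, p, c_out = 2, 16, 16, 2, 4             # hypothetical sizes
tokens = torch.randn(bs, height * width, p * p * c_out)   # shape of the proj_out_2 output

output = tokens.reshape(bs, height, width, p, p, c_out)
output = output.permute(0, 5, 1, 3, 2, 4)                 # (bs, c_out, height, p, width, p)
output = output.reshape(bs, c_out, height * p, width * p)
print(output.shape)                                       # torch.Size([2, 4, 32, 32])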
322
+ @MODULES.register_module()
323
+ class DiTTransformer2DModelMod(_DiTTransformer2DModelMod):
324
+
325
+ def __init__(
326
+ self,
327
+ *args,
328
+ freeze=False,
329
+ freeze_exclude=[],
330
+ pretrained=None,
331
+ torch_dtype='float32',
332
+ autocast_dtype=None,
333
+ freeze_exclude_fp32=True,
334
+ freeze_exclude_autocast_dtype='float32',
335
+ checkpointing=True,
336
+ **kwargs):
337
+ super().__init__(*args, **kwargs)
338
+
339
+ self.init_weights(pretrained)
340
+
341
+ if autocast_dtype is not None:
342
+ assert torch_dtype == 'float32'
343
+ self.autocast_dtype = autocast_dtype
344
+
345
+ if torch_dtype is not None:
346
+ self.to(getattr(torch, torch_dtype))
347
+
348
+ self.freeze = freeze
349
+ if self.freeze:
350
+ flex_freeze(
351
+ self,
352
+ exclude_keys=freeze_exclude,
353
+ exclude_fp32=freeze_exclude_fp32,
354
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
355
+
356
+ if checkpointing:
357
+ self.enable_gradient_checkpointing()
358
+
359
+ def init_weights(self, pretrained=None):
360
+ super().init_weights()
361
+ if pretrained is not None:
362
+ logger = get_root_logger()
363
+ # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
364
+ checkpoint = _load_checkpoint(pretrained, map_location='cpu', logger=logger)
365
+ if 'state_dict' in checkpoint:
366
+ state_dict = checkpoint['state_dict']
367
+ else:
368
+ state_dict = checkpoint
369
+ # load from GMDiT V1 model with 1 Gaussian
370
+ p2 = self.config.patch_size * self.config.patch_size
371
+ ori_out_channels = p2 * self.out_channels
372
+ if 'proj_out_2.weight' in state_dict:
373
+ # if this is GMDiT V1 model with 1 Gaussian
374
+ if state_dict['proj_out_2.weight'].size(0) == p2 * (self.out_channels + 1):
375
+ state_dict['proj_out_2.weight'] = state_dict['proj_out_2.weight'].reshape(
376
+ p2, self.out_channels + 1, -1
377
+ )[:, :-1].reshape(ori_out_channels, -1)
378
+ # if this is original DiT with variance prediction
379
+ if state_dict['proj_out_2.weight'].size(0) == 2 * ori_out_channels:
380
+ state_dict['proj_out_2.weight'] = state_dict['proj_out_2.weight'].reshape(
381
+ p2, 2 * self.out_channels, -1
382
+ )[:, :self.out_channels].reshape(ori_out_channels, -1)
383
+ if 'proj_out_2.bias' in state_dict:
384
+ # if this is GMDiT V1 model with 1 Gaussian
385
+ if state_dict['proj_out_2.bias'].size(0) == p2 * (self.out_channels + 1):
386
+ state_dict['proj_out_2.bias'] = state_dict['proj_out_2.bias'].reshape(
387
+ p2, self.out_channels + 1
388
+ )[:, :-1].reshape(ori_out_channels)
389
+ # if this is original DiT with variance prediction
390
+ if state_dict['proj_out_2.bias'].size(0) == 2 * ori_out_channels:
391
+ state_dict['proj_out_2.bias'] = state_dict['proj_out_2.bias'].reshape(
392
+ p2, 2 * self.out_channels
393
+ )[:, :self.out_channels].reshape(ori_out_channels)
394
+ if 'emb.class_embedder.embedding_table.weight' not in state_dict \
395
+ and 'transformer_blocks.0.norm1.emb.class_embedder.embedding_table.weight' in state_dict:
396
+ # convert original diffusers DiT model to our modified DiT model with shared embeddings
397
+ keys_to_remove = []
398
+ state_update = {}
399
+ for k, v in state_dict.items():
400
+ if k.startswith('transformer_blocks.0.norm1.emb.'):
401
+ new_k = k.replace('transformer_blocks.0.norm1.', '')
402
+ state_update[new_k] = v
403
+ if k.startswith('transformer_blocks.') and '.norm1.emb.' in k:
404
+ keys_to_remove.append(k)
405
+ state_dict.update(state_update)
406
+ for k in keys_to_remove:
407
+ del state_dict[k]
408
+ load_state_dict(self, state_dict, logger=logger)
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.Tensor,
413
+ timestep: Optional[torch.LongTensor] = None,
414
+ class_labels: Optional[torch.LongTensor] = None,
415
+ **kwargs):
416
+ if self.autocast_dtype is not None:
417
+ dtype = getattr(torch, self.autocast_dtype)
418
+ else:
419
+ dtype = hidden_states.dtype
420
+ with torch.autocast(
421
+ device_type='cuda',
422
+ enabled=self.autocast_dtype is not None,
423
+ dtype=dtype if self.autocast_dtype is not None else None):
424
+ return super().forward(
425
+ hidden_states.to(dtype),
426
+ timestep=timestep,
427
+ class_labels=class_labels,
428
+ **kwargs)
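Note: `init_weights` above adapts two kinds of checkpoints: it drops the extra output channels of a variance-predicting DiT (or GMDiT V1) head, and it moves the per-block `norm1.emb.*` weights of a stock diffusers DiT into the single shared `emb.*` module. A toy sketch of the key-remapping step on a fabricated state dict (key names follow the code above; sizes are illustrative):

import torch

# fabricated checkpoint with per-block embeddings, as produced by a stock diffusers DiT
state_dict = {
    'transformer_blocks.0.norm1.emb.class_embedder.embedding_table.weight': torch.zeros(1001, 8),
    'transformer_blocks.1.norm1.emb.class_embedder.embedding_table.weight': torch.zeros(1001, 8),
    'transformer_blocks.0.attn1.to_q.weight': torch.zeros(8, 8),
}

state_update, keys_to_remove = {}, []
for k, v in state_dict.items():
    if k.startswith('transformer_blocks.0.norm1.emb.'):
        state_update[k.replace('transformer_blocks.0.norm1.', '')] = v   # promote to shared 'emb.*'
    if k.startswith('transformer_blocks.') and '.norm1.emb.' in k:
        keys_to_remove.append(k)                                         # drop all per-block copies

state_dict.update(state_update)
for k in keys_to_remove:
    del state_dict[k]
print(sorted(state_dict))
# ['emb.class_embedder.embedding_table.weight', 'transformer_blocks.0.attn1.to_q.weight']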
piFlow/lakonlab/models/architecture/diffusers/flux.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+
3
+ from typing import Optional
4
+ from accelerate import init_empty_weights
5
+ from diffusers.models import FluxTransformer2DModel as _FluxTransformer2DModel
6
+ from peft import LoraConfig
7
+ from mmgen.models.builder import MODULES
8
+ from mmgen.utils import get_root_logger
9
+ from ..utils import flex_freeze
10
+ from lakonlab.runner.checkpoint import load_checkpoint, _load_checkpoint
11
+
12
+
13
+ @MODULES.register_module()
14
+ class FluxTransformer2DModel(_FluxTransformer2DModel):
15
+
16
+ def __init__(
17
+ self,
18
+ *args,
19
+ patch_size=2,
20
+ freeze=False,
21
+ freeze_exclude=[],
22
+ pretrained=None,
23
+ pretrained_lora=None,
24
+ pretrained_lora_scale=1.0,
25
+ torch_dtype='float32',
26
+ freeze_exclude_fp32=True,
27
+ freeze_exclude_autocast_dtype='float32',
28
+ checkpointing=True,
29
+ use_lora=False,
30
+ lora_target_modules=None,
31
+ lora_rank=16,
32
+ **kwargs):
33
+ with init_empty_weights():
34
+ super().__init__(patch_size=1, *args, **kwargs)
35
+ self.patch_size = patch_size
36
+
37
+ self.init_weights(pretrained, pretrained_lora, pretrained_lora_scale)
38
+
39
+ self.use_lora = use_lora
40
+ self.lora_target_modules = lora_target_modules
41
+ self.lora_rank = lora_rank
42
+ if self.use_lora:
43
+ transformer_lora_config = LoraConfig(
44
+ r=lora_rank,
45
+ lora_alpha=lora_rank,
46
+ init_lora_weights='gaussian',
47
+ target_modules=lora_target_modules,
48
+ )
49
+ self.add_adapter(transformer_lora_config)
50
+
51
+ if torch_dtype is not None:
52
+ self.to(getattr(torch, torch_dtype))
53
+
54
+ self.freeze = freeze
55
+ if self.freeze:
56
+ flex_freeze(
57
+ self,
58
+ exclude_keys=freeze_exclude,
59
+ exclude_fp32=freeze_exclude_fp32,
60
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
61
+
62
+ if checkpointing:
63
+ self.enable_gradient_checkpointing()
64
+
65
+ def init_weights(self, pretrained=None, pretrained_lora=None, pretrained_lora_scale=1.0):
66
+ if pretrained is not None:
67
+ logger = get_root_logger()
68
+ load_checkpoint(
69
+ self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
70
+ if pretrained_lora is not None:
71
+ if not isinstance(pretrained_lora, (list, tuple)):
72
+ assert isinstance(pretrained_lora, str)
73
+ pretrained_lora = [pretrained_lora]
74
+ if not isinstance(pretrained_lora_scale, (list, tuple)):
75
+ assert isinstance(pretrained_lora_scale, (int, float))
76
+ pretrained_lora_scale = [pretrained_lora_scale]
77
+ for pretrained_lora_single, pretrained_lora_scale_single in zip(pretrained_lora, pretrained_lora_scale):
78
+ lora_state_dict = _load_checkpoint(
79
+ pretrained_lora_single, map_location='cpu', logger=logger)
80
+ self.load_lora_adapter(lora_state_dict)
81
+ self.fuse_lora(lora_scale=pretrained_lora_scale_single)
82
+ self.unload_lora()
83
+
84
+ @staticmethod
85
+ def _prepare_latent_image_ids(height, width, device, dtype):
86
+ """
87
+ Copied from Diffusers
88
+ """
89
+ latent_image_ids = torch.zeros(height, width, 3)
90
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
91
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
92
+
93
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
94
+
95
+ latent_image_ids = latent_image_ids.reshape(
96
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels)
97
+
98
+ return latent_image_ids.to(device=device, dtype=dtype)
99
+
100
+ def patchify(self, latents):
101
+ if self.patch_size > 1:
102
+ bs, c, h, w = latents.size()
103
+ latents = latents.reshape(
104
+ bs, c, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size
105
+ ).permute(
106
+ 0, 1, 3, 5, 2, 4
107
+ ).reshape(
108
+ bs, c * self.patch_size * self.patch_size, h // self.patch_size, w // self.patch_size)
109
+ return latents
110
+
111
+ def unpatchify(self, latents):
112
+ if self.patch_size > 1:
113
+ bs, c, h, w = latents.size()
114
+ latents = latents.reshape(
115
+ bs, c // (self.patch_size * self.patch_size), self.patch_size, self.patch_size, h, w
116
+ ).permute(
117
+ 0, 1, 4, 2, 5, 3
118
+ ).reshape(
119
+ bs, c // (self.patch_size * self.patch_size), h * self.patch_size, w * self.patch_size)
120
+ return latents
121
+
122
+ def forward(
123
+ self,
124
+ hidden_states: torch.Tensor,
125
+ timestep: torch.Tensor,
126
+ encoder_hidden_states: torch.Tensor = None,
127
+ pooled_projections: torch.Tensor = None,
128
+ mask: Optional[torch.Tensor] = None,
129
+ masked_image_latents: Optional[torch.Tensor] = None,
130
+ **kwargs):
131
+ hidden_states = self.patchify(hidden_states)
132
+ bs, c, h, w = hidden_states.size()
133
+ dtype = hidden_states.dtype
134
+ device = hidden_states.device
135
+ hidden_states = hidden_states.reshape(bs, c, h * w).permute(0, 2, 1)
136
+ img_ids = self._prepare_latent_image_ids(
137
+ h, w, device, dtype)
138
+ txt_ids = img_ids.new_zeros((encoder_hidden_states.shape[-2], 3))
139
+
140
+ # Flux fill
141
+ if mask is not None and masked_image_latents is not None:
142
+ hidden_states = torch.cat(
143
+ (hidden_states, masked_image_latents.to(dtype=dtype), mask.to(dtype=dtype)), dim=-1)
144
+
145
+ output = super().forward(
146
+ hidden_states=hidden_states,
147
+ encoder_hidden_states=encoder_hidden_states.to(dtype),
148
+ pooled_projections=pooled_projections.to(dtype),
149
+ timestep=timestep,
150
+ img_ids=img_ids,
151
+ txt_ids=txt_ids,
152
+ return_dict=False,
153
+ **kwargs)[0]
154
+
155
+ output = output.permute(0, 2, 1).reshape(bs, self.out_channels, h, w)
156
+ return self.unpatchify(output)
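Note: `patchify`/`unpatchify` above pack a patch_size x patch_size spatial neighborhood into the channel axis before the sequence flattening that the Flux transformer expects, and exactly undo it afterwards. A self-contained round-trip check using the same reshape/permute pattern (free functions here, standing in for the methods):

import torch

def patchify(latents, p):
    bs, c, h, w = latents.size()
    return latents.reshape(bs, c, h // p, p, w // p, p).permute(0, 1, 3, 5, 2, 4).reshape(
        bs, c * p * p, h // p, w // p)

def unpatchify(latents, p):
    bs, c, h, w = latents.size()
    return latents.reshape(bs, c // (p * p), p, p, h, w).permute(0, 1, 4, 2, 5, 3).reshape(
        bs, c // (p * p), h * p, w * p)

x = torch.randn(1, 16, 64, 64)                          # illustrative latent shape
assert torch.equal(unpatchify(patchify(x, 2), 2), x)    # exact inverse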
piFlow/lakonlab/models/architecture/diffusers/pretrained.py ADDED
@@ -0,0 +1,281 @@
1
+ import logging
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from diffusers.models import AutoencoderKL, AutoencoderKLQwenImage
7
+ from diffusers.pipelines import FluxPipeline, QwenImagePipeline, StableDiffusion3Pipeline
8
+ from mmgen.models.builder import MODULES
9
+
10
+ # Suppress truncation warnings from transformers and diffusers
11
+ for name in (
12
+ 'transformers.tokenization_utils_base',
13
+ 'transformers.tokenization_utils',
14
+ 'transformers.tokenization_utils_fast'):
15
+ logging.getLogger(name).setLevel(logging.ERROR)
16
+
17
+ for name, logger in logging.root.manager.loggerDict.items():
18
+ if isinstance(logger, logging.Logger) and (name.startswith('diffusers.pipelines.')):
19
+ logger.setLevel(logging.ERROR)
20
+
21
+
22
+ @MODULES.register_module()
23
+ class PretrainedVAE(nn.Module):
24
+ def __init__(self,
25
+ from_pretrained=None,
26
+ del_encoder=False,
27
+ del_decoder=False,
28
+ use_slicing=False,
29
+ freeze=True,
30
+ eval_mode=True,
31
+ torch_dtype='float32',
32
+ **kwargs):
33
+ super().__init__()
34
+ if torch_dtype is not None:
35
+ kwargs.update(torch_dtype=getattr(torch, torch_dtype))
36
+ self.vae = AutoencoderKL.from_pretrained(
37
+ from_pretrained, **kwargs)
38
+ if del_encoder:
39
+ del self.vae.encoder
40
+ if del_decoder:
41
+ del self.vae.decoder
42
+ if use_slicing:
43
+ self.vae.enable_slicing()
44
+ self.freeze = freeze
45
+ self.eval_mode = eval_mode
46
+ if self.freeze:
47
+ self.requires_grad_(False)
48
+ if self.eval_mode:
49
+ self.eval()
50
+ self.vae.set_use_memory_efficient_attention_xformers(
51
+ not hasattr(torch.nn.functional, 'scaled_dot_product_attention'))
52
+
53
+ def train(self, mode=True):
54
+ mode = mode and (not self.eval_mode)
55
+ return super().train(mode)
56
+
57
+ def forward(self, *args, **kwargs):
58
+ return self.vae(*args, return_dict=False, **kwargs)[0]
59
+
60
+ def encode(self, img):
61
+ scaling_factor = self.vae.config.scaling_factor
62
+ shift_factor = self.vae.config.shift_factor
63
+ if scaling_factor is None:
64
+ scaling_factor = 1.0
65
+ if shift_factor is None:
66
+ shift_factor = 0.0
67
+ return (self.vae.encode(img).latent_dist.sample() - shift_factor) * scaling_factor
68
+
69
+ def decode(self, code):
70
+ scaling_factor = self.vae.config.scaling_factor
71
+ shift_factor = self.vae.config.shift_factor
72
+ if scaling_factor is None:
73
+ scaling_factor = 1.0
74
+ if shift_factor is None:
75
+ shift_factor = 0.0
76
+ return self.vae.decode(code / scaling_factor + shift_factor, return_dict=False)[0]
77
+
78
+
79
+ @MODULES.register_module()
80
+ class PretrainedVAEDecoder(PretrainedVAE):
81
+ def __init__(self, **kwargs):
82
+ super().__init__(
83
+ del_encoder=True,
84
+ del_decoder=False,
85
+ **kwargs)
86
+
87
+ def forward(self, code):
88
+ return super().decode(code)
89
+
90
+
91
+ @MODULES.register_module()
92
+ class PretrainedVAEEncoder(PretrainedVAE):
93
+ def __init__(self, **kwargs):
94
+ super().__init__(
95
+ del_encoder=False,
96
+ del_decoder=True,
97
+ **kwargs)
98
+
99
+ def forward(self, img):
100
+ return super().encode(img)
101
+
102
+
103
+ @MODULES.register_module()
104
+ class PretrainedVAEQwenImage(nn.Module):
105
+ def __init__(self,
106
+ from_pretrained=None,
107
+ use_slicing=False,
108
+ freeze=True,
109
+ eval_mode=True,
110
+ torch_dtype='float32',
111
+ **kwargs):
112
+ super().__init__()
113
+ if torch_dtype is not None:
114
+ kwargs.update(torch_dtype=getattr(torch, torch_dtype))
115
+ self.vae = AutoencoderKLQwenImage.from_pretrained(
116
+ from_pretrained, **kwargs)
117
+ if use_slicing:
118
+ self.vae.enable_slicing()
119
+ self.freeze = freeze
120
+ self.eval_mode = eval_mode
121
+ if self.freeze:
122
+ self.requires_grad_(False)
123
+ if self.eval_mode:
124
+ self.eval()
125
+
126
+ def train(self, mode=True):
127
+ mode = mode and (not self.eval_mode)
128
+ return super().train(mode)
129
+
130
+ def forward(self, *args, **kwargs):
131
+ return self.vae(*args, return_dict=False, **kwargs)[0]
132
+
133
+ def encode(self, img):
134
+ device = img.device
135
+ dtype = img.dtype
136
+ latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=dtype).view(
137
+ 1, self.vae.config.z_dim, 1, 1, 1)
138
+ latents_std = torch.tensor(self.vae.config.latents_std, device=device, dtype=dtype).view(
139
+ 1, self.vae.config.z_dim, 1, 1, 1)
140
+ return ((self.vae.encode(img.unsqueeze(-3)).latent_dist.sample() - latents_mean) / latents_std).squeeze(-3)
141
+
142
+ def decode(self, code):
143
+ device = code.device
144
+ dtype = code.dtype
145
+ latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=dtype).view(
146
+ 1, self.vae.config.z_dim, 1, 1, 1)
147
+ latents_std = torch.tensor(self.vae.config.latents_std, device=device, dtype=dtype).view(
148
+ 1, self.vae.config.z_dim, 1, 1, 1)
149
+ return self.vae.decode(code.unsqueeze(-3) * latents_std + latents_mean, return_dict=False)[0].squeeze(-3)
150
+
151
+
152
+ @MODULES.register_module()
153
+ class PretrainedFluxTextEncoder(nn.Module):
154
+ def __init__(self,
155
+ from_pretrained='black-forest-labs/FLUX.1-dev',
156
+ freeze=True,
157
+ eval_mode=True,
158
+ torch_dtype='bfloat16',
159
+ max_sequence_length=512,
160
+ **kwargs):
161
+ super().__init__()
162
+ self.max_sequence_length = max_sequence_length
163
+ self.pipeline = FluxPipeline.from_pretrained(
164
+ from_pretrained,
165
+ scheduler=None,
166
+ vae=None,
167
+ transformer=None,
168
+ image_encoder=None,
169
+ feature_extractor=None,
170
+ torch_dtype=getattr(torch, torch_dtype),
171
+ **kwargs)
172
+ self.text_encoder = self.pipeline.text_encoder
173
+ self.text_encoder_2 = self.pipeline.text_encoder_2
174
+ self.freeze = freeze
175
+ self.eval_mode = eval_mode
176
+ if self.freeze:
177
+ self.requires_grad_(False)
178
+ if self.eval_mode:
179
+ self.eval()
180
+
181
+ def train(self, mode=True):
182
+ mode = mode and (not self.eval_mode)
183
+ return super().train(mode)
184
+
185
+ def forward(self, prompt, prompt_2=None):
186
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.pipeline.encode_prompt(
187
+ prompt, prompt_2=prompt_2, max_sequence_length=self.max_sequence_length)
188
+ return dict(
189
+ encoder_hidden_states=prompt_embeds,
190
+ pooled_projections=pooled_prompt_embeds)
191
+
192
+
193
+ @MODULES.register_module()
194
+ class PretrainedQwenImageTextEncoder(nn.Module):
195
+ def __init__(self,
196
+ from_pretrained='Qwen/Qwen-Image',
197
+ freeze=True,
198
+ eval_mode=True,
199
+ torch_dtype='bfloat16',
200
+ max_sequence_length=512,
201
+ pad_seq_len=None,
202
+ **kwargs):
203
+ super().__init__()
204
+ self.max_sequence_length = max_sequence_length
205
+ if pad_seq_len is not None:
206
+ assert pad_seq_len >= max_sequence_length
207
+ self.pad_seq_len = pad_seq_len
208
+ self.pipeline = QwenImagePipeline.from_pretrained(
209
+ from_pretrained,
210
+ scheduler=None,
211
+ vae=None,
212
+ transformer=None,
213
+ torch_dtype=getattr(torch, torch_dtype),
214
+ **kwargs)
215
+ self.text_encoder = self.pipeline.text_encoder
216
+ self.freeze = freeze
217
+ self.eval_mode = eval_mode
218
+ if self.freeze:
219
+ self.requires_grad_(False)
220
+ if self.eval_mode:
221
+ self.eval()
222
+
223
+ def train(self, mode=True):
224
+ mode = mode and (not self.eval_mode)
225
+ return super().train(mode)
226
+
227
+ def forward(self, prompt):
228
+ prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
229
+ prompt, max_sequence_length=self.max_sequence_length)
230
+ if self.pad_seq_len is not None:
231
+ pad_len = self.pad_seq_len - prompt_embeds.size(1)
232
+ prompt_embeds = F.pad(
233
+ prompt_embeds, (0, 0, 0, pad_len), value=0.0)
234
+ prompt_embeds_mask = F.pad(
235
+ prompt_embeds_mask, (0, pad_len), value=0.0)
236
+ return dict(
237
+ encoder_hidden_states=prompt_embeds,
238
+ encoder_hidden_states_mask=prompt_embeds_mask)
239
+
240
+
241
+ @MODULES.register_module()
242
+ class PretrainedStableDiffusion3TextEncoder(nn.Module):
243
+ def __init__(self,
244
+ from_pretrained='stabilityai/stable-diffusion-3.5-large',
245
+ freeze=True,
246
+ eval_mode=True,
247
+ torch_dtype='float32',
248
+ max_sequence_length=256,
249
+ **kwargs):
250
+ super().__init__()
251
+ self.max_sequence_length = max_sequence_length
252
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
253
+ from_pretrained,
254
+ scheduler=None,
255
+ vae=None,
256
+ transformer=None,
257
+ image_encoder=None,
258
+ feature_extractor=None,
259
+ torch_dtype=getattr(torch, torch_dtype),
260
+ **kwargs)
261
+ self.text_encoder = self.pipeline.text_encoder
262
+ self.text_encoder_2 = self.pipeline.text_encoder_2
263
+ self.text_encoder_3 = self.pipeline.text_encoder_3
264
+ self.freeze = freeze
265
+ self.eval_mode = eval_mode
266
+ if self.freeze:
267
+ self.requires_grad_(False)
268
+ if self.eval_mode:
269
+ self.eval()
270
+
271
+ def train(self, mode=True):
272
+ mode = mode and (not self.eval_mode)
273
+ return super().train(mode)
274
+
275
+ def forward(self, prompt, prompt_2=None, prompt_3=None):
276
+ prompt_embeds, _, pooled_prompt_embeds, _ = self.pipeline.encode_prompt(
277
+ prompt, prompt_2=prompt_2, prompt_3=prompt_3, do_classifier_free_guidance=False,
278
+ max_sequence_length=self.max_sequence_length)
279
+ return dict(
280
+ encoder_hidden_states=prompt_embeds,
281
+ pooled_projections=pooled_prompt_embeds)
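Note: the VAE wrappers above normalize latents with the model's `scaling_factor`/`shift_factor` (or, for Qwen-Image, per-channel `latents_mean`/`latents_std`) so that downstream networks see roughly unit-scale codes. A minimal sketch of the scalar case, with made-up factor values, showing that `decode`'s de-normalization inverts `encode`'s normalization:

import torch

scaling_factor, shift_factor = 0.3611, 0.1159     # illustrative values only
raw_latent = torch.randn(1, 16, 64, 64)

normalized = (raw_latent - shift_factor) * scaling_factor   # what encode() returns
recovered = normalized / scaling_factor + shift_factor      # what decode() feeds back to the VAE
assert torch.allclose(recovered, raw_latent, atol=1e-6)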
piFlow/lakonlab/models/architecture/diffusers/qwen.py ADDED
@@ -0,0 +1,139 @@
1
+ import torch
2
+
3
+ from accelerate import init_empty_weights
4
+ from diffusers.models import QwenImageTransformer2DModel as _QwenImageTransformer2DModel
5
+ from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_qwen_lora_to_diffusers
6
+ from peft import LoraConfig
7
+ from mmgen.models.builder import MODULES
8
+ from mmgen.utils import get_root_logger
9
+ from ..utils import flex_freeze
10
+ from lakonlab.runner.checkpoint import load_checkpoint, _load_checkpoint
11
+
12
+
13
+ @MODULES.register_module()
14
+ class QwenImageTransformer2DModel(_QwenImageTransformer2DModel):
15
+
16
+ def __init__(
17
+ self,
18
+ *args,
19
+ patch_size=2,
20
+ freeze=False,
21
+ freeze_exclude=[],
22
+ pretrained=None,
23
+ pretrained_lora=None,
24
+ pretrained_lora_scale=1.0,
25
+ torch_dtype='float32',
26
+ freeze_exclude_fp32=True,
27
+ freeze_exclude_autocast_dtype='float32',
28
+ checkpointing=True,
29
+ use_lora=False,
30
+ lora_target_modules=None,
31
+ lora_rank=16,
32
+ **kwargs):
33
+ with init_empty_weights():
34
+ super().__init__(*args, patch_size=1, **kwargs)
35
+ self.patch_size = patch_size
36
+
37
+ self.init_weights(pretrained, pretrained_lora, pretrained_lora_scale)
38
+
39
+ self.use_lora = use_lora
40
+ self.lora_target_modules = lora_target_modules
41
+ self.lora_rank = lora_rank
42
+ if self.use_lora:
43
+ transformer_lora_config = LoraConfig(
44
+ r=lora_rank,
45
+ lora_alpha=lora_rank,
46
+ init_lora_weights='gaussian',
47
+ target_modules=lora_target_modules,
48
+ )
49
+ self.add_adapter(transformer_lora_config)
50
+
51
+ if torch_dtype is not None:
52
+ self.to(getattr(torch, torch_dtype))
53
+
54
+ self.freeze = freeze
55
+ if self.freeze:
56
+ flex_freeze(
57
+ self,
58
+ exclude_keys=freeze_exclude,
59
+ exclude_fp32=freeze_exclude_fp32,
60
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
61
+
62
+ if checkpointing:
63
+ self.enable_gradient_checkpointing()
64
+
65
+ def init_weights(self, pretrained=None, pretrained_lora=None, pretrained_lora_scale=1.0):
66
+ if pretrained is not None:
67
+ logger = get_root_logger()
68
+ load_checkpoint(
69
+ self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
70
+ if pretrained_lora is not None:
71
+ if not isinstance(pretrained_lora, (list, tuple)):
72
+ assert isinstance(pretrained_lora, str)
73
+ pretrained_lora = [pretrained_lora]
74
+ if not isinstance(pretrained_lora_scale, (list, tuple)):
75
+ assert isinstance(pretrained_lora_scale, (int, float))
76
+ pretrained_lora_scale = [pretrained_lora_scale]
77
+ for pretrained_lora_single, pretrained_lora_scale_single in zip(pretrained_lora, pretrained_lora_scale):
78
+ lora_state_dict = _load_checkpoint(
79
+ pretrained_lora_single, map_location='cpu', logger=logger)
80
+ lora_state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(lora_state_dict)
81
+ self.load_lora_adapter(lora_state_dict)
82
+ self.fuse_lora(lora_scale=pretrained_lora_scale_single)
83
+ self.unload_lora()
84
+
85
+ def patchify(self, latents):
86
+ if self.patch_size > 1:
87
+ bs, c, h, w = latents.size()
88
+ latents = latents.reshape(
89
+ bs, c, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size
90
+ ).permute(
91
+ 0, 1, 3, 5, 2, 4
92
+ ).reshape(
93
+ bs, c * self.patch_size * self.patch_size, h // self.patch_size, w // self.patch_size)
94
+ return latents
95
+
96
+ def unpatchify(self, latents):
97
+ if self.patch_size > 1:
98
+ bs, c, h, w = latents.size()
99
+ latents = latents.reshape(
100
+ bs, c // (self.patch_size * self.patch_size), self.patch_size, self.patch_size, h, w
101
+ ).permute(
102
+ 0, 1, 4, 2, 5, 3
103
+ ).reshape(
104
+ bs, c // (self.patch_size * self.patch_size), h * self.patch_size, w * self.patch_size)
105
+ return latents
106
+
107
+ def forward(
108
+ self,
109
+ hidden_states: torch.Tensor,
110
+ timestep: torch.Tensor,
111
+ encoder_hidden_states: torch.Tensor = None,
112
+ encoder_hidden_states_mask: torch.Tensor = None,
113
+ **kwargs):
114
+ hidden_states = self.patchify(hidden_states)
115
+ bs, c, h, w = hidden_states.size()
116
+ dtype = hidden_states.dtype
117
+ hidden_states = hidden_states.reshape(bs, c, h * w).permute(0, 2, 1)
118
+ img_shapes = [[(1, h, w)]]
119
+ if encoder_hidden_states_mask is not None:
120
+ txt_seq_lens = encoder_hidden_states_mask.sum(dim=1)
121
+ max_txt_seq_len = txt_seq_lens.max()
122
+ encoder_hidden_states = encoder_hidden_states[:, :max_txt_seq_len]
123
+ encoder_hidden_states_mask = encoder_hidden_states_mask[:, :max_txt_seq_len]
124
+ txt_seq_lens = txt_seq_lens.tolist()
125
+ else:
126
+ txt_seq_lens = None
127
+
128
+ output = super().forward(
129
+ hidden_states=hidden_states,
130
+ encoder_hidden_states=encoder_hidden_states.to(dtype),
131
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
132
+ timestep=timestep,
133
+ img_shapes=img_shapes,
134
+ txt_seq_lens=txt_seq_lens,
135
+ return_dict=False,
136
+ **kwargs)[0]
137
+
138
+ output = output.permute(0, 2, 1).reshape(bs, self.out_channels, h, w)
139
+ return self.unpatchify(output)
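Note: before calling the underlying QwenImage transformer, the forward pass above trims padded text tokens: it keeps only the first `max_txt_seq_len` positions of the prompt embeddings and converts the per-sample mask sums into the `txt_seq_lens` list. A small sketch of that trimming on dummy tensors:

import torch

bs, pad_len, dim = 2, 10, 4
encoder_hidden_states = torch.randn(bs, pad_len, dim)
# sample 0 has 6 valid tokens, sample 1 has 4
encoder_hidden_states_mask = torch.tensor([[1] * 6 + [0] * 4, [1] * 4 + [0] * 6])

txt_seq_lens = encoder_hidden_states_mask.sum(dim=1)
max_txt_seq_len = txt_seq_lens.max()
encoder_hidden_states = encoder_hidden_states[:, :max_txt_seq_len]
encoder_hidden_states_mask = encoder_hidden_states_mask[:, :max_txt_seq_len]
print(encoder_hidden_states.shape, txt_seq_lens.tolist())   # torch.Size([2, 6, 4]) [6, 4]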
piFlow/lakonlab/models/architecture/diffusers/sd3.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+
3
+ from accelerate import init_empty_weights
4
+ from diffusers.models import SD3Transformer2DModel as _SD3Transformer2DModel
5
+ from peft import LoraConfig
6
+ from mmgen.models.builder import MODULES
7
+ from mmgen.utils import get_root_logger
8
+ from ..utils import flex_freeze
9
+ from lakonlab.runner.checkpoint import load_checkpoint
10
+
11
+
12
+ @MODULES.register_module()
13
+ class SD3Transformer2DModel(_SD3Transformer2DModel):
14
+
15
+ def __init__(
16
+ self,
17
+ *args,
18
+ freeze=False,
19
+ freeze_exclude=[],
20
+ pretrained=None,
21
+ torch_dtype='float32',
22
+ freeze_exclude_fp32=True,
23
+ freeze_exclude_autocast_dtype='float32',
24
+ checkpointing=True,
25
+ use_lora=False,
26
+ lora_target_modules=None,
27
+ lora_rank=16,
28
+ **kwargs):
29
+ with init_empty_weights():
30
+ super().__init__(*args, **kwargs)
31
+ self.init_weights(pretrained)
32
+
33
+ self.use_lora = use_lora
34
+ self.lora_target_modules = lora_target_modules
35
+ self.lora_rank = lora_rank
36
+ if self.use_lora:
37
+ transformer_lora_config = LoraConfig(
38
+ r=lora_rank,
39
+ lora_alpha=lora_rank,
40
+ init_lora_weights='gaussian',
41
+ target_modules=lora_target_modules,
42
+ )
43
+ self.add_adapter(transformer_lora_config)
44
+
45
+ if torch_dtype is not None:
46
+ self.to(getattr(torch, torch_dtype))
47
+
48
+ self.freeze = freeze
49
+ if self.freeze:
50
+ flex_freeze(
51
+ self,
52
+ exclude_keys=freeze_exclude,
53
+ exclude_fp32=freeze_exclude_fp32,
54
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
55
+
56
+ if checkpointing:
57
+ self.enable_gradient_checkpointing()
58
+
59
+ def init_weights(self, pretrained=None):
60
+ if pretrained is not None:
61
+ logger = get_root_logger()
62
+ load_checkpoint(
63
+ self, pretrained, map_location='cpu', strict=False, logger=logger, assign=True)
64
+
65
+ def forward(
66
+ self,
67
+ hidden_states: torch.Tensor,
68
+ timestep: torch.Tensor,
69
+ encoder_hidden_states: torch.Tensor = None,
70
+ pooled_projections: torch.Tensor = None,
71
+ **kwargs):
72
+ dtype = hidden_states.dtype
73
+
74
+ return super().forward(
75
+ hidden_states=hidden_states,
76
+ encoder_hidden_states=encoder_hidden_states.to(dtype),
77
+ pooled_projections=pooled_projections.to(dtype),
78
+ timestep=timestep,
79
+ return_dict=False,
80
+ **kwargs)[0]
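Note: like the Flux and Qwen wrappers, this class is registered in mmgen's `MODULES` registry, so it is normally built from a config dict rather than instantiated directly. A hedged sketch of such a config (the checkpoint path and LoRA target names are placeholders, not values taken from this repository):

# hypothetical config; 'pretrained' and 'lora_target_modules' are placeholders
sd3_cfg = dict(
    type='SD3Transformer2DModel',
    pretrained='path/to/sd3_transformer.safetensors',
    torch_dtype='bfloat16',
    freeze=True,
    freeze_exclude=['proj_out'],
    use_lora=True,
    lora_target_modules=['to_q', 'to_k', 'to_v', 'to_out.0'],
    lora_rank=16,
    checkpointing=True,
)
# built elsewhere with something like: transformer = MODULES.build(sd3_cfg)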
piFlow/lakonlab/models/architecture/diffusers/unet.py ADDED
@@ -0,0 +1,192 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from typing import Dict, Any, Optional, Union, Tuple
5
+ from collections import OrderedDict
6
+ from diffusers.models import UNet2DConditionModel as _UNet2DConditionModel
7
+ from mmcv.runner import _load_checkpoint, load_state_dict
8
+ from mmgen.models.builder import MODULES
9
+ from mmgen.utils import get_root_logger
10
+ from ..utils import flex_freeze
11
+
12
+
13
+ def ceildiv(a, b):
14
+ return -(a // -b)
15
+
16
+
17
+ def unet_enc(
18
+ unet,
19
+ sample: torch.FloatTensor,
20
+ timestep: Union[torch.Tensor, float, int],
21
+ encoder_hidden_states: torch.Tensor,
22
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
23
+ added_cond_kwargs=None):
24
+ # 0. center input if necessary
25
+ if unet.config.center_input_sample:
26
+ sample = 2 * sample - 1.0
27
+
28
+ # 1. time
29
+ t_emb = unet.get_time_embed(sample=sample, timestep=timestep)
30
+ emb = unet.time_embedding(t_emb)
31
+ aug_emb = unet.get_aug_embed(
32
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
33
+ emb = emb + aug_emb if aug_emb is not None else emb
34
+
35
+ if unet.time_embed_act is not None:
36
+ emb = unet.time_embed_act(emb)
37
+
38
+ encoder_hidden_states = unet.process_encoder_hidden_states(
39
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
40
+
41
+ # 2. pre-process
42
+ sample = unet.conv_in(sample)
43
+
44
+ # 3. down
45
+ down_block_res_samples = (sample,)
46
+ for downsample_block in unet.down_blocks:
47
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
48
+ sample, res_samples = downsample_block(
49
+ hidden_states=sample,
50
+ temb=emb,
51
+ encoder_hidden_states=encoder_hidden_states,
52
+ cross_attention_kwargs=cross_attention_kwargs,
53
+ )
54
+ else:
55
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
56
+
57
+ down_block_res_samples += res_samples
58
+
59
+ return emb, down_block_res_samples, sample
60
+
61
+
62
+ def unet_dec(
63
+ unet,
64
+ emb,
65
+ down_block_res_samples,
66
+ sample,
67
+ encoder_hidden_states: torch.Tensor,
68
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
69
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
70
+ mid_block_additional_residual: Optional[torch.Tensor] = None):
71
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
72
+
73
+ if is_controlnet:
74
+ new_down_block_res_samples = ()
75
+
76
+ for down_block_res_sample, down_block_additional_residual in zip(
77
+ down_block_res_samples, down_block_additional_residuals):
78
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
79
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
80
+
81
+ down_block_res_samples = new_down_block_res_samples
82
+
83
+ # 4. mid
84
+ if unet.mid_block is not None:
85
+ if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention:
86
+ sample = unet.mid_block(
87
+ sample,
88
+ emb,
89
+ encoder_hidden_states=encoder_hidden_states,
90
+ cross_attention_kwargs=cross_attention_kwargs,
91
+ )
92
+ else:
93
+ sample = unet.mid_block(sample, emb)
94
+
95
+ if is_controlnet:
96
+ sample = sample + mid_block_additional_residual
97
+
98
+ # 5. up
99
+ for i, upsample_block in enumerate(unet.up_blocks):
100
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
101
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
102
+
103
+ if hasattr(upsample_block, 'has_cross_attention') and upsample_block.has_cross_attention:
104
+ sample = upsample_block(
105
+ hidden_states=sample,
106
+ temb=emb,
107
+ res_hidden_states_tuple=res_samples,
108
+ encoder_hidden_states=encoder_hidden_states,
109
+ cross_attention_kwargs=cross_attention_kwargs,
110
+ )
111
+ else:
112
+ sample = upsample_block(
113
+ hidden_states=sample,
114
+ temb=emb,
115
+ res_hidden_states_tuple=res_samples,
116
+ )
117
+
118
+ # 6. post-process
119
+ if unet.conv_norm_out:
120
+ sample = unet.conv_norm_out(sample)
121
+ sample = unet.conv_act(sample)
122
+ sample = unet.conv_out(sample)
123
+
124
+ return sample
125
+
126
+
127
+ @MODULES.register_module()
128
+ class UNet2DConditionModel(_UNet2DConditionModel):
129
+ def __init__(self,
130
+ *args,
131
+ freeze=True,
132
+ freeze_exclude=[],
133
+ pretrained=None,
134
+ torch_dtype='float32',
135
+ freeze_exclude_fp32=True,
136
+ freeze_exclude_autocast_dtype='float32',
137
+ **kwargs):
138
+ super().__init__(*args, **kwargs)
139
+
140
+ self.init_weights(pretrained)
141
+ if torch_dtype is not None:
142
+ self.to(getattr(torch, torch_dtype))
143
+
144
+ self.set_use_memory_efficient_attention_xformers(
145
+ not hasattr(torch.nn.functional, 'scaled_dot_product_attention'))
146
+
147
+ self.freeze = freeze
148
+ if self.freeze:
149
+ flex_freeze(
150
+ self,
151
+ exclude_keys=freeze_exclude,
152
+ exclude_fp32=freeze_exclude_fp32,
153
+ exclude_autocast_dtype=freeze_exclude_autocast_dtype)
154
+
155
+ def init_weights(self, pretrained):
156
+ if pretrained is not None:
157
+ logger = get_root_logger()
158
+ # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
159
+ checkpoint = _load_checkpoint(pretrained, map_location='cpu', logger=logger)
160
+ if 'state_dict' in checkpoint:
161
+ state_dict = checkpoint['state_dict']
162
+ else:
163
+ state_dict = checkpoint
164
+ metadata = getattr(state_dict, '_metadata', OrderedDict())
165
+ state_dict._metadata = metadata
166
+ assert self.conv_in.weight.shape[1] == self.conv_out.weight.shape[0]
167
+ if state_dict['conv_in.weight'].size() != self.conv_in.weight.size():
168
+ assert state_dict['conv_in.weight'].shape[1] == state_dict['conv_out.weight'].shape[0]
169
+ src_chn = state_dict['conv_in.weight'].shape[1]
170
+ tgt_chn = self.conv_in.weight.shape[1]
171
+ assert src_chn < tgt_chn
172
+ convert_mat_out = torch.tile(torch.eye(src_chn), (ceildiv(tgt_chn, src_chn), 1))
173
+ convert_mat_out = convert_mat_out[:tgt_chn]
174
+ convert_mat_in = F.normalize(convert_mat_out.pinverse(), dim=-1)
175
+ state_dict['conv_out.weight'] = torch.einsum(
176
+ 'ts,scxy->tcxy', convert_mat_out, state_dict['conv_out.weight'])
177
+ state_dict['conv_out.bias'] = torch.einsum(
178
+ 'ts,s->t', convert_mat_out, state_dict['conv_out.bias'])
179
+ state_dict['conv_in.weight'] = torch.einsum(
180
+ 'st,csxy->ctxy', convert_mat_in, state_dict['conv_in.weight'])
181
+ load_state_dict(self, state_dict, logger=logger)
182
+
183
+ def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
184
+ dtype = sample.dtype
185
+ return super().forward(
186
+ sample, timestep, encoder_hidden_states, return_dict=False, **kwargs)[0].to(dtype)
187
+
188
+ def forward_enc(self, sample, timestep, encoder_hidden_states, **kwargs):
189
+ return unet_enc(self, sample, timestep, encoder_hidden_states, **kwargs)
190
+
191
+ def forward_dec(self, emb, down_block_res_samples, sample, encoder_hidden_states, **kwargs):
192
+ return unet_dec(self, emb, down_block_res_samples, sample, encoder_hidden_states, **kwargs)
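Note: `init_weights` above widens a pretrained UNet whose `conv_in`/`conv_out` were trained with fewer latent channels: `convert_mat_out` tiles an identity matrix up to the target channel count, and its row-normalized pseudo-inverse maps the widened input channels back onto the original ones. A small sketch of the matrix construction with hypothetical channel counts:

import torch
import torch.nn.functional as F

def ceildiv(a, b):
    return -(a // -b)

src_chn, tgt_chn = 4, 9                       # hypothetical channel counts (pretrained -> new)
convert_mat_out = torch.tile(torch.eye(src_chn), (ceildiv(tgt_chn, src_chn), 1))[:tgt_chn]
convert_mat_in = F.normalize(convert_mat_out.pinverse(), dim=-1)

conv_out_w = torch.randn(src_chn, 8, 3, 3)    # original conv_out weight (src_chn, c, k, k)
conv_in_w = torch.randn(8, src_chn, 3, 3)     # original conv_in weight (c, src_chn, k, k)
new_out_w = torch.einsum('ts,scxy->tcxy', convert_mat_out, conv_out_w)   # (tgt_chn, 8, 3, 3)
new_in_w = torch.einsum('st,csxy->ctxy', convert_mat_in, conv_in_w)      # (8, tgt_chn, 3, 3)
print(new_out_w.shape, new_in_w.shape)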