.gitattributes CHANGED
@@ -33,13 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- celeb_meme.jpg filter=lfs diff=lfs merge=lfs -text
37
- cookie.png filter=lfs diff=lfs merge=lfs -text
38
- leather.jpg filter=lfs diff=lfs merge=lfs -text
39
- no_cookie.png filter=lfs diff=lfs merge=lfs -text
40
- poster_orig.jpg filter=lfs diff=lfs merge=lfs -text
41
- examples[[:space:]]2/celeb_meme.jpg filter=lfs diff=lfs merge=lfs -text
42
- examples[[:space:]]2/cookie.png filter=lfs diff=lfs merge=lfs -text
43
- examples[[:space:]]2/leather.jpg filter=lfs diff=lfs merge=lfs -text
44
- examples[[:space:]]2/no_cookie.png filter=lfs diff=lfs merge=lfs -text
45
- examples[[:space:]]2/poster_orig.jpg filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
__init__.cpython-310.pyc DELETED
Binary file (128 Bytes)
 
__init__.py DELETED
File without changes
app.py CHANGED
@@ -8,6 +8,7 @@ import spaces
8
  import time
9
  from pathlib import Path
10
 
 
11
  import gradio as gr
12
  import numpy as np
13
  import torch
@@ -60,7 +61,7 @@ def load_models(
60
  dit_path=None,
61
  ae_path=None,
62
  qwen2vl_model_path=None,
63
- device="cpu",
64
  max_length=256,
65
  dtype=torch.bfloat16,
66
  ):
@@ -117,7 +118,7 @@ class ImageGenerator:
117
  dit_path=None,
118
  ae_path=None,
119
  qwen2vl_model_path=None,
120
- device="cpu",
121
  max_length=640,
122
  dtype=torch.bfloat16,
123
  ) -> None:
@@ -134,9 +135,9 @@ class ImageGenerator:
134
  self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
135
 
136
  def to_cuda(self):
137
- self.ae.to(device='cpu', dtype=torch.float32)
138
- self.dit.to(device='cpu', dtype=torch.bfloat16)
139
- self.llm_encoder.to(device='cpu', dtype=torch.bfloat16)
140
 
141
  def prepare(self, prompt, img, ref_image, ref_image_raw):
142
  bs, _, h, w = img.shape
@@ -487,4 +488,5 @@ with gr.Blocks() as demo:
487
  fn=generate_examples,
488
  cache_examples=True
489
  )
 
490
  demo.launch()
 
8
  import time
9
  from pathlib import Path
10
 
11
+
12
  import gradio as gr
13
  import numpy as np
14
  import torch
 
61
  dit_path=None,
62
  ae_path=None,
63
  qwen2vl_model_path=None,
64
+ device="cuda",
65
  max_length=256,
66
  dtype=torch.bfloat16,
67
  ):
 
118
  dit_path=None,
119
  ae_path=None,
120
  qwen2vl_model_path=None,
121
+ device="cuda",
122
  max_length=640,
123
  dtype=torch.bfloat16,
124
  ) -> None:
 
135
  self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
136
 
137
  def to_cuda(self):
138
+ self.ae.to(device='cuda', dtype=torch.float32)
139
+ self.dit.to(device='cuda', dtype=torch.bfloat16)
140
+ self.llm_encoder.to(device='cuda', dtype=torch.bfloat16)
141
 
142
  def prepare(self, prompt, img, ref_image, ref_image_raw):
143
  bs, _, h, w = img.shape
 
488
  fn=generate_examples,
489
  cache_examples=True
490
  )
491
+
492
  demo.launch()
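
Note on the app.py change above: the default `device` flips from "cpu" to "cuda", and `to_cuda()` now actually moves the three sub-modules onto the GPU instead of pinning them to the CPU. The snippet below is a minimal, self-contained sketch of that placement pattern only; `DummyPipeline` and the `Identity` stand-ins are illustrative and not part of the Space.

import torch

class DummyPipeline:
    """Illustrative stand-in for ImageGenerator's device handling after this change."""
    def __init__(self, ae, dit, llm_encoder, device="cuda", dtype=torch.bfloat16):
        self.ae, self.dit, self.llm_encoder = ae, dit, llm_encoder
        self.device, self.dtype = device, dtype

    def to_cuda(self):
        # VAE stays in float32 for numerical stability; DiT and text encoder run in bfloat16.
        self.ae.to(device="cuda", dtype=torch.float32)
        self.dit.to(device="cuda", dtype=torch.bfloat16)
        self.llm_encoder.to(device="cuda", dtype=torch.bfloat16)

pipe = DummyPipeline(torch.nn.Identity(), torch.nn.Identity(), torch.nn.Identity())
if torch.cuda.is_available():
    pipe.to_cuda()
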
attention.cpython-310.pyc DELETED
Binary file (3.13 kB)
 
attention.py DELETED
@@ -1,133 +0,0 @@
1
- import math
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
-
7
- try:
8
- import flash_attn
9
- from flash_attn.flash_attn_interface import (
10
- _flash_attn_forward,
11
- flash_attn_func,
12
- flash_attn_varlen_func,
13
- )
14
- except ImportError:
15
- flash_attn = None
16
- flash_attn_varlen_func = None
17
- _flash_attn_forward = None
18
- flash_attn_func = None
19
-
20
- MEMORY_LAYOUT = {
21
- # "flash" mode:
22
- # preprocess: input is [batch_size, seq_len, num_heads, head_dim]
23
- # postprocess: shape is left unchanged
24
- "flash": (
25
- lambda x: x, # keep shape
26
- lambda x: x, # keep shape
27
- ),
28
- # "torch"/"vanilla" mode:
29
- # preprocess: swap the sequence and head dimensions [B,S,A,D] -> [B,A,S,D]
30
- # postprocess: swap back to the original order [B,A,S,D] -> [B,S,A,D]
31
- "torch": (
32
- lambda x: x.transpose(1, 2), # (B,S,A,D) -> (B,A,S,D)
33
- lambda x: x.transpose(1, 2), # (B,A,S,D) -> (B,S,A,D)
34
- ),
35
- "vanilla": (
36
- lambda x: x.transpose(1, 2),
37
- lambda x: x.transpose(1, 2),
38
- ),
39
- }
40
-
41
-
42
- def attention(
43
- q,
44
- k,
45
- v,
46
- mode="torch",
47
- drop_rate=0,
48
- attn_mask=None,
49
- causal=False,
50
- ):
51
- """
52
- Perform QKV self-attention.
53
-
54
- Args:
55
- q (torch.Tensor): query tensor of shape [batch_size, seq_len, num_heads, head_dim]
56
- k (torch.Tensor): key tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
57
- v (torch.Tensor): value tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
58
- mode (str): attention mode, one of 'flash', 'torch', 'vanilla'
59
- drop_rate (float): dropout probability applied to the attention matrix
60
- attn_mask (torch.Tensor): attention mask; its shape depends on the mode
61
- causal (bool): whether to use causal attention (attend only to earlier positions)
62
-
63
- Returns:
64
- torch.Tensor: attention output of shape [batch_size, seq_len, num_heads * head_dim]
65
- """
66
- # look up the preprocessing and postprocessing functions
67
- pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
68
-
69
- # apply the preprocessing transform
70
- q = pre_attn_layout(q) # shape depends on the mode
71
- k = pre_attn_layout(k)
72
- v = pre_attn_layout(v)
73
-
74
- if mode == "torch":
75
- # use PyTorch's native scaled_dot_product_attention
76
- if attn_mask is not None and attn_mask.dtype != torch.bool:
77
- attn_mask = attn_mask.to(q.dtype)
78
- x = F.scaled_dot_product_attention(
79
- q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
80
- )
81
- elif mode == "flash":
82
- assert flash_attn_func is not None, "flash_attn_func is not available"
83
- assert attn_mask is None, "attn_mask is not supported in flash mode"
84
- x: torch.Tensor = flash_attn_func(
85
- q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
86
- ) # type: ignore
87
- elif mode == "vanilla":
88
- # manual (reference) attention implementation
89
- scale_factor = 1 / math.sqrt(q.size(-1)) # scaling factor 1/sqrt(d_k)
90
-
91
- b, a, s, _ = q.shape # unpack shape parameters
92
- s1 = k.size(2) # key/value sequence length
93
-
94
- # initialize the attention bias
95
- attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
96
-
97
- # handle the causal mask
98
- if causal:
99
- assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
100
- # build a lower-triangular causal mask
101
- temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
102
- diagonal=0
103
- )
104
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
105
- attn_bias = attn_bias.to(q.dtype)
106
-
107
- # handle a user-provided attention mask
108
- if attn_mask is not None:
109
- if attn_mask.dtype == torch.bool:
110
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
111
- else:
112
- attn_bias += attn_mask # allows ALiBi-style positional biases
113
-
114
- # compute the attention matrix
115
- attn = (q @ k.transpose(-2, -1)) * scale_factor # [B,A,S,S1]
116
- attn += attn_bias
117
-
118
- # softmax and dropout
119
- attn = attn.softmax(dim=-1)
120
- attn = torch.dropout(attn, p=drop_rate, train=True)
121
-
122
- # compute the output
123
- x = attn @ v # [B,A,S,D]
124
- else:
125
- raise NotImplementedError(f"Unsupported attention mode: {mode}")
126
-
127
- # apply the postprocessing transform
128
- x = post_attn_layout(x) # restore the original dimension order
129
-
130
- # merge the attention-head dimension
131
- b, s, a, d = x.shape
132
- out = x.reshape(b, s, -1) # [B,S,A*D]
133
- return out
 
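For reference, the deleted attention.py exposed a single attention() helper with three interchangeable backends ("flash", "torch", "vanilla") that all take [batch, seq, heads, head_dim] inputs and return [batch, seq, heads * head_dim]. A minimal usage sketch against that signature (shapes are arbitrary; "torch" mode needs only PyTorch, not flash-attn):

import torch
from attention import attention  # the module removed in this diff

B, S, H, D = 2, 16, 8, 64
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
v = torch.randn(B, S, H, D)

# "torch" mode transposes to [B, H, S, D], calls F.scaled_dot_product_attention,
# then merges the heads back into the last dimension.
out = attention(q, k, v, mode="torch", causal=False)
print(out.shape)  # torch.Size([2, 16, 512])
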
autoencoder.cpython-310.pyc DELETED
Binary file (8.78 kB)
 
autoencoder.py DELETED
@@ -1,326 +0,0 @@
1
- # Modified from Flux
2
- #
3
- # Copyright 2024 Black Forest Labs
4
-
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
-
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
-
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
- # This source code is licensed under the license found in the
18
- # LICENSE file in the root directory of this source tree.
19
- import torch
20
- from einops import rearrange
21
- from torch import Tensor, nn
22
-
23
-
24
- def swish(x: Tensor) -> Tensor:
25
- return x * torch.sigmoid(x)
26
-
27
-
28
- class AttnBlock(nn.Module):
29
- def __init__(self, in_channels: int):
30
- super().__init__()
31
- self.in_channels = in_channels
32
-
33
- self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34
-
35
- self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
- self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
37
- self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
38
- self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
39
-
40
- def attention(self, h_: Tensor) -> Tensor:
41
- h_ = self.norm(h_)
42
- q = self.q(h_)
43
- k = self.k(h_)
44
- v = self.v(h_)
45
-
46
- b, c, h, w = q.shape
47
- q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
48
- k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
49
- v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
50
- h_ = nn.functional.scaled_dot_product_attention(q, k, v)
51
-
52
- return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
53
-
54
- def forward(self, x: Tensor) -> Tensor:
55
- return x + self.proj_out(self.attention(x))
56
-
57
-
58
- class ResnetBlock(nn.Module):
59
- def __init__(self, in_channels: int, out_channels: int):
60
- super().__init__()
61
- self.in_channels = in_channels
62
- out_channels = in_channels if out_channels is None else out_channels
63
- self.out_channels = out_channels
64
-
65
- self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
66
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
67
- self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
68
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
69
- if self.in_channels != self.out_channels:
70
- self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
71
-
72
- def forward(self, x):
73
- h = x
74
- h = self.norm1(h)
75
- h = swish(h)
76
- h = self.conv1(h)
77
-
78
- h = self.norm2(h)
79
- h = swish(h)
80
- h = self.conv2(h)
81
-
82
- if self.in_channels != self.out_channels:
83
- x = self.nin_shortcut(x)
84
-
85
- return x + h
86
-
87
-
88
- class Downsample(nn.Module):
89
- def __init__(self, in_channels: int):
90
- super().__init__()
91
- # no asymmetric padding in torch conv, must do it ourselves
92
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
93
-
94
- def forward(self, x: Tensor):
95
- pad = (0, 1, 0, 1)
96
- x = nn.functional.pad(x, pad, mode="constant", value=0)
97
- x = self.conv(x)
98
- return x
99
-
100
-
101
- class Upsample(nn.Module):
102
- def __init__(self, in_channels: int):
103
- super().__init__()
104
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
105
-
106
- def forward(self, x: Tensor):
107
- x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
108
- x = self.conv(x)
109
- return x
110
-
111
-
112
- class Encoder(nn.Module):
113
- def __init__(
114
- self,
115
- resolution: int,
116
- in_channels: int,
117
- ch: int,
118
- ch_mult: list[int],
119
- num_res_blocks: int,
120
- z_channels: int,
121
- ):
122
- super().__init__()
123
- self.ch = ch
124
- self.num_resolutions = len(ch_mult)
125
- self.num_res_blocks = num_res_blocks
126
- self.resolution = resolution
127
- self.in_channels = in_channels
128
- # downsampling
129
- self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
130
-
131
- curr_res = resolution
132
- in_ch_mult = (1, *tuple(ch_mult))
133
- self.in_ch_mult = in_ch_mult
134
- self.down = nn.ModuleList()
135
- block_in = self.ch
136
- for i_level in range(self.num_resolutions):
137
- block = nn.ModuleList()
138
- attn = nn.ModuleList()
139
- block_in = ch * in_ch_mult[i_level]
140
- block_out = ch * ch_mult[i_level]
141
- for _ in range(self.num_res_blocks):
142
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
143
- block_in = block_out
144
- down = nn.Module()
145
- down.block = block
146
- down.attn = attn
147
- if i_level != self.num_resolutions - 1:
148
- down.downsample = Downsample(block_in)
149
- curr_res = curr_res // 2
150
- self.down.append(down)
151
-
152
- # middle
153
- self.mid = nn.Module()
154
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
155
- self.mid.attn_1 = AttnBlock(block_in)
156
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
157
-
158
- # end
159
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
160
- self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
161
-
162
- def forward(self, x: Tensor) -> Tensor:
163
- # downsampling
164
- hs = [self.conv_in(x)]
165
- for i_level in range(self.num_resolutions):
166
- for i_block in range(self.num_res_blocks):
167
- h = self.down[i_level].block[i_block](hs[-1])
168
- if len(self.down[i_level].attn) > 0:
169
- h = self.down[i_level].attn[i_block](h)
170
- hs.append(h)
171
- if i_level != self.num_resolutions - 1:
172
- hs.append(self.down[i_level].downsample(hs[-1]))
173
-
174
- # middle
175
- h = hs[-1]
176
- h = self.mid.block_1(h)
177
- h = self.mid.attn_1(h)
178
- h = self.mid.block_2(h)
179
- # end
180
- h = self.norm_out(h)
181
- h = swish(h)
182
- h = self.conv_out(h)
183
- return h
184
-
185
-
186
- class Decoder(nn.Module):
187
- def __init__(
188
- self,
189
- ch: int,
190
- out_ch: int,
191
- ch_mult: list[int],
192
- num_res_blocks: int,
193
- in_channels: int,
194
- resolution: int,
195
- z_channels: int,
196
- ):
197
- super().__init__()
198
- self.ch = ch
199
- self.num_resolutions = len(ch_mult)
200
- self.num_res_blocks = num_res_blocks
201
- self.resolution = resolution
202
- self.in_channels = in_channels
203
- self.ffactor = 2 ** (self.num_resolutions - 1)
204
-
205
- # compute in_ch_mult, block_in and curr_res at lowest res
206
- block_in = ch * ch_mult[self.num_resolutions - 1]
207
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
208
- self.z_shape = (1, z_channels, curr_res, curr_res)
209
-
210
- # z to block_in
211
- self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
212
-
213
- # middle
214
- self.mid = nn.Module()
215
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
216
- self.mid.attn_1 = AttnBlock(block_in)
217
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
218
-
219
- # upsampling
220
- self.up = nn.ModuleList()
221
- for i_level in reversed(range(self.num_resolutions)):
222
- block = nn.ModuleList()
223
- attn = nn.ModuleList()
224
- block_out = ch * ch_mult[i_level]
225
- for _ in range(self.num_res_blocks + 1):
226
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
227
- block_in = block_out
228
- up = nn.Module()
229
- up.block = block
230
- up.attn = attn
231
- if i_level != 0:
232
- up.upsample = Upsample(block_in)
233
- curr_res = curr_res * 2
234
- self.up.insert(0, up) # prepend to get consistent order
235
-
236
- # end
237
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
238
- self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
239
-
240
- def forward(self, z: Tensor) -> Tensor:
241
- # z to block_in
242
- h = self.conv_in(z)
243
-
244
- # middle
245
- h = self.mid.block_1(h)
246
- h = self.mid.attn_1(h)
247
- h = self.mid.block_2(h)
248
-
249
- # upsampling
250
- for i_level in reversed(range(self.num_resolutions)):
251
- for i_block in range(self.num_res_blocks + 1):
252
- h = self.up[i_level].block[i_block](h)
253
- if len(self.up[i_level].attn) > 0:
254
- h = self.up[i_level].attn[i_block](h)
255
- if i_level != 0:
256
- h = self.up[i_level].upsample(h)
257
-
258
- # end
259
- h = self.norm_out(h)
260
- h = swish(h)
261
- h = self.conv_out(h)
262
- return h
263
-
264
-
265
- class DiagonalGaussian(nn.Module):
266
- def __init__(self, sample: bool = True, chunk_dim: int = 1):
267
- super().__init__()
268
- self.sample = sample
269
- self.chunk_dim = chunk_dim
270
-
271
- def forward(self, z: Tensor) -> Tensor:
272
- mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
273
- if self.sample:
274
- std = torch.exp(0.5 * logvar)
275
- return mean + std * torch.randn_like(mean)
276
- else:
277
- return mean
278
-
279
-
280
- class AutoEncoder(nn.Module):
281
- def __init__(
282
- self,
283
- resolution: int,
284
- in_channels: int,
285
- ch: int,
286
- out_ch: int,
287
- ch_mult: list[int],
288
- num_res_blocks: int,
289
- z_channels: int,
290
- scale_factor: float,
291
- shift_factor: float,
292
- ):
293
- super().__init__()
294
- self.encoder = Encoder(
295
- resolution=resolution,
296
- in_channels=in_channels,
297
- ch=ch,
298
- ch_mult=ch_mult,
299
- num_res_blocks=num_res_blocks,
300
- z_channels=z_channels,
301
- )
302
- self.decoder = Decoder(
303
- resolution=resolution,
304
- in_channels=in_channels,
305
- ch=ch,
306
- out_ch=out_ch,
307
- ch_mult=ch_mult,
308
- num_res_blocks=num_res_blocks,
309
- z_channels=z_channels,
310
- )
311
- self.reg = DiagonalGaussian()
312
-
313
- self.scale_factor = scale_factor
314
- self.shift_factor = shift_factor
315
-
316
- def encode(self, x: Tensor) -> Tensor:
317
- z = self.reg(self.encoder(x))
318
- z = self.scale_factor * (z - self.shift_factor)
319
- return z
320
-
321
- def decode(self, z: Tensor) -> Tensor:
322
- z = z / self.scale_factor + self.shift_factor
323
- return self.decoder(z)
324
-
325
- def forward(self, x: Tensor) -> Tensor:
326
- return self.decode(self.encode(x))
 
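For reference, the deleted autoencoder.py was a Flux-style KL autoencoder: encode() samples a diagonal-Gaussian latent and applies the scale/shift, decode() inverts it. A round-trip sketch with illustrative hyperparameters (the actual checkpoint configuration may differ):

import torch
from autoencoder import AutoEncoder  # the module removed in this diff

ae = AutoEncoder(
    resolution=256, in_channels=3, ch=128, out_ch=3,
    ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
    scale_factor=0.3611, shift_factor=0.1159,  # illustrative values
)

x = torch.randn(1, 3, 256, 256)
z = ae.encode(x)      # latent [1, 16, 32, 32]; downsample factor is 2 ** (len(ch_mult) - 1)
x_rec = ae.decode(z)  # reconstruction [1, 3, 256, 256]
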
conditioner.cpython-310.pyc DELETED
Binary file (4.94 kB)
 
conditioner.py DELETED
@@ -1,216 +0,0 @@
1
- import torch
2
- from qwen_vl_utils import process_vision_info
3
- from transformers import (
4
- AutoProcessor,
5
- Qwen2VLForConditionalGeneration,
6
- Qwen2_5_VLForConditionalGeneration,
7
- )
8
- from torchvision.transforms import ToPILImage
9
-
10
- to_pil = ToPILImage()
11
-
12
- Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
13
- - If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
14
- - If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
15
- Here are examples of how to transform or refine prompts:
16
- - User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
17
- - User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
18
- Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
19
- User Prompt:'''
20
-
21
-
22
- def split_string(s):
23
- # replace Chinese quotation marks with English ones
24
- s = s.replace("“", '"').replace("”", '"') # use english quotes
25
- result = []
26
- # flag marking whether we are inside quotes
27
- in_quotes = False
28
- temp = ""
29
-
30
- # iterate over every character in the string together with its index
31
- for idx, char in enumerate(s):
32
- # if the character is a quote and its index is greater than 155
33
- if char == '"' and idx > 155:
34
- # append the quote to the temporary string
35
- temp += char
36
- # if we are not currently inside quotes
37
- if not in_quotes:
38
- # append the temporary string to the result list
39
- result.append(temp)
40
- # reset the temporary string
41
- temp = ""
42
-
43
- # toggle the in-quotes state
44
- in_quotes = not in_quotes
45
- continue
46
- # if we are inside quotes
47
- if in_quotes:
48
- # if the character is whitespace
49
- if char.isspace():
50
- pass # have space token
51
-
52
- # wrap the character in Chinese quotes and append it to the result list
53
- result.append("“" + char + "”")
54
- else:
55
- # append the character to the temporary string
56
- temp += char
57
-
58
- # if the temporary string is not empty
59
- if temp:
60
- # append the temporary string to the result list
61
- result.append(temp)
62
-
63
- return result
64
-
65
-
66
- class Qwen25VL_7b_Embedder(torch.nn.Module):
67
- def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
68
- super(Qwen25VL_7b_Embedder, self).__init__()
69
- self.max_length = max_length
70
- self.dtype = dtype
71
- self.device = device
72
-
73
- self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
74
- model_path,
75
- torch_dtype=dtype,
76
- attn_implementation="eager",
77
- ).to(torch.cuda.current_device())
78
-
79
- self.model.requires_grad_(False)
80
- self.processor = AutoProcessor.from_pretrained(
81
- model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
82
- )
83
-
84
- self.prefix = Qwen25VL_7b_PREFIX
85
-
86
- def forward(self, caption, ref_images):
87
- text_list = caption
88
- embs = torch.zeros(
89
- len(text_list),
90
- self.max_length,
91
- self.model.config.hidden_size,
92
- dtype=torch.bfloat16,
93
- device=torch.cuda.current_device(),
94
- )
95
- hidden_states = torch.zeros(
96
- len(text_list),
97
- self.max_length,
98
- self.model.config.hidden_size,
99
- dtype=torch.bfloat16,
100
- device=torch.cuda.current_device(),
101
- )
102
- masks = torch.zeros(
103
- len(text_list),
104
- self.max_length,
105
- dtype=torch.long,
106
- device=torch.cuda.current_device(),
107
- )
108
- input_ids_list = []
109
- attention_mask_list = []
110
- emb_list = []
111
-
112
- def split_string(s):
113
- s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
114
- result = []
115
- in_quotes = False
116
- temp = ""
117
-
118
- for idx,char in enumerate(s):
119
- if char == '"' and idx>155:
120
- temp += char
121
- if not in_quotes:
122
- result.append(temp)
123
- temp = ""
124
-
125
- in_quotes = not in_quotes
126
- continue
127
- if in_quotes:
128
- if char.isspace():
129
- pass # have space token
130
-
131
- result.append("“" + char + "”")
132
- else:
133
- temp += char
134
-
135
- if temp:
136
- result.append(temp)
137
-
138
- return result
139
-
140
- for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
141
-
142
- messages = [{"role": "user", "content": []}]
143
-
144
- messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
145
-
146
- messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})
147
-
148
- # then append the text
149
- messages[0]["content"].append({"type": "text", "text": f"{txt}"})
150
-
151
- # Preparation for inference
152
- text = self.processor.apply_chat_template(
153
- messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
154
- )
155
-
156
- image_inputs, video_inputs = process_vision_info(messages)
157
-
158
- inputs = self.processor(
159
- text=[text],
160
- images=image_inputs,
161
- padding=True,
162
- return_tensors="pt",
163
- )
164
-
165
- old_inputs_ids = inputs.input_ids
166
- text_split_list = split_string(text)
167
-
168
- token_list = []
169
- for text_each in text_split_list:
170
- txt_inputs = self.processor(
171
- text=text_each,
172
- images=None,
173
- videos=None,
174
- padding=True,
175
- return_tensors="pt",
176
- )
177
- token_each = txt_inputs.input_ids
178
- if token_each[0][0] == 2073 and token_each[0][-1] == 854:
179
- token_each = token_each[:, 1:-1]
180
- token_list.append(token_each)
181
- else:
182
- token_list.append(token_each)
183
-
184
- new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
185
-
186
- new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
187
-
188
- idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
189
- idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
190
- inputs.input_ids = (
191
- torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
192
- .unsqueeze(0)
193
- .to("cuda")
194
- )
195
- inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
196
- outputs = self.model(
197
- input_ids=inputs.input_ids,
198
- attention_mask=inputs.attention_mask,
199
- pixel_values=inputs.pixel_values.to("cuda"),
200
- image_grid_thw=inputs.image_grid_thw.to("cuda"),
201
- output_hidden_states=True,
202
- )
203
-
204
- emb = outputs["hidden_states"][-1]
205
-
206
- embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
207
- : self.max_length
208
- ]
209
-
210
- masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
211
- (min(self.max_length, emb.shape[1] - 217)),
212
- dtype=torch.long,
213
- device=torch.cuda.current_device(),
214
- )
215
-
216
- return embs, masks
 
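For reference, the deleted conditioner.py wrapped Qwen2.5-VL-7B as the prompt encoder: each caption is prefixed with the prompt-enhancement instruction, paired with one reference image, and the last hidden states past the 217 prefix/vision tokens are packed into fixed-length embeddings plus a validity mask. A usage sketch against that interface (the model path is a placeholder; a CUDA device is required because the buffers are allocated on torch.cuda.current_device()):

import torch
from conditioner import Qwen25VL_7b_Embedder  # the module removed in this diff

embedder = Qwen25VL_7b_Embedder(
    "Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder model path
    max_length=640, dtype=torch.bfloat16, device="cuda",
)

captions = ["Replace the cookie on the plate with a red apple"]
ref_images = [torch.rand(3, 448, 448)]  # one CHW tensor in [0, 1] per caption; converted to PIL internally

embs, masks = embedder(captions, ref_images)
# embs: [1, 640, hidden_size] bfloat16 prompt embeddings; masks: [1, 640] long validity mask
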
connector_edit.cpython-310.pyc DELETED
Binary file (11.8 kB)
 
connector_edit.py DELETED
@@ -1,486 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- import torch.nn
5
- from einops import rearrange
6
- from torch import nn
7
-
8
- from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention
9
-
10
-
11
- class RMSNorm(nn.Module):
12
- def __init__(
13
- self,
14
- dim: int,
15
- elementwise_affine=True,
16
- eps: float = 1e-6,
17
- device=None,
18
- dtype=None,
19
- ):
20
- """
21
- Initialize the RMSNorm normalization layer.
22
-
23
- Args:
24
- dim (int): The dimension of the input tensor.
25
- eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
26
-
27
- Attributes:
28
- eps (float): A small value added to the denominator for numerical stability.
29
- weight (nn.Parameter): Learnable scaling parameter.
30
-
31
- """
32
- factory_kwargs = {"device": device, "dtype": dtype}
33
- super().__init__()
34
- self.eps = eps
35
- if elementwise_affine:
36
- self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
37
-
38
- def _norm(self, x):
39
- """
40
- Apply the RMSNorm normalization to the input tensor.
41
-
42
- Args:
43
- x (torch.Tensor): The input tensor.
44
-
45
- Returns:
46
- torch.Tensor: The normalized tensor.
47
-
48
- """
49
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
50
-
51
- def forward(self, x):
52
- """
53
- Forward pass through the RMSNorm layer.
54
-
55
- Args:
56
- x (torch.Tensor): The input tensor.
57
-
58
- Returns:
59
- torch.Tensor: The output tensor after applying RMSNorm.
60
-
61
- """
62
- output = self._norm(x.float()).type_as(x)
63
- if hasattr(self, "weight"):
64
- output = output * self.weight
65
- return output
66
-
67
-
68
- def get_norm_layer(norm_layer):
69
- """
70
- Get the normalization layer.
71
-
72
- Args:
73
- norm_layer (str): The type of normalization layer.
74
-
75
- Returns:
76
- norm_layer (nn.Module): The normalization layer.
77
- """
78
- if norm_layer == "layer":
79
- return nn.LayerNorm
80
- elif norm_layer == "rms":
81
- return RMSNorm
82
- else:
83
- raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
84
-
85
-
86
- def get_activation_layer(act_type):
87
- """get activation layer
88
-
89
- Args:
90
- act_type (str): the activation type
91
-
92
- Returns:
93
- torch.nn.functional: the activation layer
94
- """
95
- if act_type == "gelu":
96
- return lambda: nn.GELU()
97
- elif act_type == "gelu_tanh":
98
- return lambda: nn.GELU(approximate="tanh")
99
- elif act_type == "relu":
100
- return nn.ReLU
101
- elif act_type == "silu":
102
- return nn.SiLU
103
- else:
104
- raise ValueError(f"Unknown activation type: {act_type}")
105
-
106
- class IndividualTokenRefinerBlock(torch.nn.Module):
107
- def __init__(
108
- self,
109
- hidden_size,
110
- heads_num,
111
- mlp_width_ratio: str = 4.0,
112
- mlp_drop_rate: float = 0.0,
113
- act_type: str = "silu",
114
- qk_norm: bool = False,
115
- qk_norm_type: str = "layer",
116
- qkv_bias: bool = True,
117
- need_CA: bool = False,
118
- dtype: Optional[torch.dtype] = None,
119
- device: Optional[torch.device] = None,
120
- ):
121
- factory_kwargs = {"device": device, "dtype": dtype}
122
- super().__init__()
123
- self.need_CA = need_CA
124
- self.heads_num = heads_num
125
- head_dim = hidden_size // heads_num
126
- mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
127
-
128
- self.norm1 = nn.LayerNorm(
129
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
130
- )
131
- self.self_attn_qkv = nn.Linear(
132
- hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
133
- )
134
- qk_norm_layer = get_norm_layer(qk_norm_type)
135
- self.self_attn_q_norm = (
136
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
137
- if qk_norm
138
- else nn.Identity()
139
- )
140
- self.self_attn_k_norm = (
141
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
142
- if qk_norm
143
- else nn.Identity()
144
- )
145
- self.self_attn_proj = nn.Linear(
146
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
147
- )
148
-
149
- self.norm2 = nn.LayerNorm(
150
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
151
- )
152
- act_layer = get_activation_layer(act_type)
153
- self.mlp = MLP(
154
- in_channels=hidden_size,
155
- hidden_channels=mlp_hidden_dim,
156
- act_layer=act_layer,
157
- drop=mlp_drop_rate,
158
- **factory_kwargs,
159
- )
160
-
161
- self.adaLN_modulation = nn.Sequential(
162
- act_layer(),
163
- nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
164
- )
165
-
166
- if self.need_CA:
167
- self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
168
- heads_num=heads_num,
169
- mlp_width_ratio=mlp_width_ratio,
170
- mlp_drop_rate=mlp_drop_rate,
171
- act_type=act_type,
172
- qk_norm=qk_norm,
173
- qk_norm_type=qk_norm_type,
174
- qkv_bias=qkv_bias,
175
- **factory_kwargs,)
176
- # Zero-initialize the modulation
177
- nn.init.zeros_(self.adaLN_modulation[1].weight)
178
- nn.init.zeros_(self.adaLN_modulation[1].bias)
179
-
180
- def forward(
181
- self,
182
- x: torch.Tensor,
183
- c: torch.Tensor, # timestep_aware_representations + context_aware_representations
184
- attn_mask: torch.Tensor = None,
185
- y: torch.Tensor = None,
186
- ):
187
- gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
188
-
189
- norm_x = self.norm1(x)
190
- qkv = self.self_attn_qkv(norm_x)
191
- q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
192
- # Apply QK-Norm if needed
193
- q = self.self_attn_q_norm(q).to(v)
194
- k = self.self_attn_k_norm(k).to(v)
195
-
196
- # Self-Attention
197
- attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
198
-
199
- x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
200
-
201
- if self.need_CA:
202
- x = self.cross_attnblock(x, c, attn_mask, y)
203
-
204
- # FFN Layer
205
- x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
206
-
207
- return x
208
-
209
-
210
-
211
-
212
- class CrossAttnBlock(torch.nn.Module):
213
- def __init__(
214
- self,
215
- hidden_size,
216
- heads_num,
217
- mlp_width_ratio: str = 4.0,
218
- mlp_drop_rate: float = 0.0,
219
- act_type: str = "silu",
220
- qk_norm: bool = False,
221
- qk_norm_type: str = "layer",
222
- qkv_bias: bool = True,
223
- dtype: Optional[torch.dtype] = None,
224
- device: Optional[torch.device] = None,
225
- ):
226
- factory_kwargs = {"device": device, "dtype": dtype}
227
- super().__init__()
228
- self.heads_num = heads_num
229
- head_dim = hidden_size // heads_num
230
-
231
- self.norm1 = nn.LayerNorm(
232
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
233
- )
234
- self.norm1_2 = nn.LayerNorm(
235
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
236
- )
237
- self.self_attn_q = nn.Linear(
238
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
239
- )
240
- self.self_attn_kv = nn.Linear(
241
- hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
242
- )
243
- qk_norm_layer = get_norm_layer(qk_norm_type)
244
- self.self_attn_q_norm = (
245
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
246
- if qk_norm
247
- else nn.Identity()
248
- )
249
- self.self_attn_k_norm = (
250
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
251
- if qk_norm
252
- else nn.Identity()
253
- )
254
- self.self_attn_proj = nn.Linear(
255
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
256
- )
257
-
258
- self.norm2 = nn.LayerNorm(
259
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
260
- )
261
- act_layer = get_activation_layer(act_type)
262
-
263
- self.adaLN_modulation = nn.Sequential(
264
- act_layer(),
265
- nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
266
- )
267
- # Zero-initialize the modulation
268
- nn.init.zeros_(self.adaLN_modulation[1].weight)
269
- nn.init.zeros_(self.adaLN_modulation[1].bias)
270
-
271
- def forward(
272
- self,
273
- x: torch.Tensor,
274
- c: torch.Tensor, # timestep_aware_representations + context_aware_representations
275
- attn_mask: torch.Tensor = None,
276
- y: torch.Tensor=None,
277
-
278
- ):
279
- gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
280
-
281
- norm_x = self.norm1(x)
282
- norm_y = self.norm1_2(y)
283
- q = self.self_attn_q(norm_x)
284
- q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
285
- kv = self.self_attn_kv(norm_y)
286
- k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
287
- # Apply QK-Norm if needed
288
- q = self.self_attn_q_norm(q).to(v)
289
- k = self.self_attn_k_norm(k).to(v)
290
-
291
- # Self-Attention
292
- attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
293
-
294
- x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
295
-
296
- return x
297
-
298
-
299
-
300
- class IndividualTokenRefiner(torch.nn.Module):
301
- def __init__(
302
- self,
303
- hidden_size,
304
- heads_num,
305
- depth,
306
- mlp_width_ratio: float = 4.0,
307
- mlp_drop_rate: float = 0.0,
308
- act_type: str = "silu",
309
- qk_norm: bool = False,
310
- qk_norm_type: str = "layer",
311
- qkv_bias: bool = True,
312
- need_CA:bool=False,
313
- dtype: Optional[torch.dtype] = None,
314
- device: Optional[torch.device] = None,
315
- ):
316
-
317
- factory_kwargs = {"device": device, "dtype": dtype}
318
- super().__init__()
319
- self.need_CA = need_CA
320
- self.blocks = nn.ModuleList(
321
- [
322
- IndividualTokenRefinerBlock(
323
- hidden_size=hidden_size,
324
- heads_num=heads_num,
325
- mlp_width_ratio=mlp_width_ratio,
326
- mlp_drop_rate=mlp_drop_rate,
327
- act_type=act_type,
328
- qk_norm=qk_norm,
329
- qk_norm_type=qk_norm_type,
330
- qkv_bias=qkv_bias,
331
- need_CA=self.need_CA,
332
- **factory_kwargs,
333
- )
334
- for _ in range(depth)
335
- ]
336
- )
337
-
338
-
339
- def forward(
340
- self,
341
- x: torch.Tensor,
342
- c: torch.LongTensor,
343
- mask: Optional[torch.Tensor] = None,
344
- y:torch.Tensor=None,
345
- ):
346
- self_attn_mask = None
347
- if mask is not None:
348
- batch_size = mask.shape[0]
349
- seq_len = mask.shape[1]
350
- mask = mask.to(x.device)
351
- # batch_size x 1 x seq_len x seq_len
352
- self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
353
- 1, 1, seq_len, 1
354
- )
355
- # batch_size x 1 x seq_len x seq_len
356
- self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
357
- # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
358
- self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
359
- # avoids self-attention weight being NaN for padding tokens
360
- self_attn_mask[:, :, :, 0] = True
361
-
362
-
363
- for block in self.blocks:
364
- x = block(x, c, self_attn_mask,y)
365
-
366
- return x
367
-
368
-
369
- class SingleTokenRefiner(torch.nn.Module):
370
- """
371
- A single token refiner block for llm text embedding refine.
372
- """
373
- def __init__(
374
- self,
375
- in_channels,
376
- hidden_size,
377
- heads_num,
378
- depth,
379
- mlp_width_ratio: float = 4.0,
380
- mlp_drop_rate: float = 0.0,
381
- act_type: str = "silu",
382
- qk_norm: bool = False,
383
- qk_norm_type: str = "layer",
384
- qkv_bias: bool = True,
385
- need_CA:bool=False,
386
- attn_mode: str = "torch",
387
- dtype: Optional[torch.dtype] = None,
388
- device: Optional[torch.device] = None,
389
- ):
390
- factory_kwargs = {"device": device, "dtype": dtype}
391
- super().__init__()
392
- self.attn_mode = attn_mode
393
- self.need_CA = need_CA
394
- assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
395
-
396
- self.input_embedder = nn.Linear(
397
- in_channels, hidden_size, bias=True, **factory_kwargs
398
- )
399
- if self.need_CA:
400
- self.input_embedder_CA = nn.Linear(
401
- in_channels, hidden_size, bias=True, **factory_kwargs
402
- )
403
-
404
- act_layer = get_activation_layer(act_type)
405
- # Build timestep embedding layer
406
- self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
407
- # Build context embedding layer
408
- self.c_embedder = TextProjection(
409
- in_channels, hidden_size, act_layer, **factory_kwargs
410
- )
411
-
412
- self.individual_token_refiner = IndividualTokenRefiner(
413
- hidden_size=hidden_size,
414
- heads_num=heads_num,
415
- depth=depth,
416
- mlp_width_ratio=mlp_width_ratio,
417
- mlp_drop_rate=mlp_drop_rate,
418
- act_type=act_type,
419
- qk_norm=qk_norm,
420
- qk_norm_type=qk_norm_type,
421
- qkv_bias=qkv_bias,
422
- need_CA=need_CA,
423
- **factory_kwargs,
424
- )
425
-
426
- def forward(
427
- self,
428
- x: torch.Tensor,
429
- t: torch.LongTensor,
430
- mask: Optional[torch.LongTensor] = None,
431
- y: torch.LongTensor=None,
432
- ):
433
- timestep_aware_representations = self.t_embedder(t)
434
-
435
- if mask is None:
436
- context_aware_representations = x.mean(dim=1)
437
- else:
438
- mask_float = mask.unsqueeze(-1) # [b, s1, 1]
439
- context_aware_representations = (x * mask_float).sum(
440
- dim=1
441
- ) / mask_float.sum(dim=1)
442
- context_aware_representations = self.c_embedder(context_aware_representations)
443
- c = timestep_aware_representations + context_aware_representations
444
-
445
- x = self.input_embedder(x)
446
- if self.need_CA:
447
- y = self.input_embedder_CA(y)
448
- x = self.individual_token_refiner(x, c, mask, y)
449
- else:
450
- x = self.individual_token_refiner(x, c, mask)
451
-
452
- return x
453
-
454
-
455
-
456
- class Qwen2Connector(torch.nn.Module):
457
- def __init__(
458
- self,
459
- # biclip_dim=1024,
460
- in_channels=3584,
461
- hidden_size=4096,
462
- heads_num=32,
463
- depth=2,
464
- need_CA=False,
465
- device=None,
466
- dtype=torch.bfloat16,
467
- ):
468
- super().__init__()
469
- factory_kwargs = {"device": device, "dtype":dtype}
470
-
471
- self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
472
- self.global_proj_out=nn.Linear(in_channels,768)
473
-
474
- self.scale_factor = nn.Parameter(torch.zeros(1))
475
- with torch.no_grad():
476
- self.scale_factor.data += -(1 - 0.09)
477
-
478
- def forward(self, x,t,mask):
479
- mask_float = mask.unsqueeze(-1) # [b, s1, 1]
480
- x_mean = (x * mask_float).sum(
481
- dim=1
482
- ) / mask_float.sum(dim=1) * (1 + self.scale_factor)
483
-
484
- global_out=self.global_proj_out(x_mean)
485
- encoder_hidden_states = self.S(x,t,mask)
486
- return encoder_hidden_states,global_out
 
 
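For reference, the deleted connector_edit.py refined the Qwen2.5-VL hidden states for the DiT: Qwen2Connector runs a SingleTokenRefiner over the token sequence and additionally projects a masked mean of the inputs to a 768-d global vector. A shape-level sketch (float32 is used here because global_proj_out is created without a dtype; the import assumes the removed files are loadable as a package, since they use relative imports):

import torch
from connector_edit import Qwen2Connector  # removed in this diff; uses package-relative imports

connector = Qwen2Connector(
    in_channels=3584, hidden_size=4096, heads_num=32, depth=2, dtype=torch.float32
)

B, S = 1, 640
x = torch.randn(B, S, 3584)                # Qwen2.5-VL last hidden states
t = torch.zeros(B)                         # one timestep per sample
mask = torch.ones(B, S, dtype=torch.long)  # 1 = valid token

encoder_hidden_states, global_out = connector(x, t, mask)
# encoder_hidden_states: [1, 640, 4096] refined tokens; global_out: [1, 768] global conditioning vector
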
cookie.png → examples 2.zip RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:328bf4f4779cd6235016a217eaf5dc1ef7a8f1cb95e8fbd7ee538ac6824e75b0
3
- size 542518
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5de0f67d94e0e46599bc9619a912a05898b8053ddc0f1f6563a3ee3b4dd1f7c7
3
+ size 1878523
examples 2/celeb_meme.jpg DELETED

Git LFS Details

  • SHA256: 4fb2ccab4218dba753781d65e8f5933f8ab7613543b59a7b4512a6654fe55a4f
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
examples 2/cookie.png DELETED

Git LFS Details

  • SHA256: 328bf4f4779cd6235016a217eaf5dc1ef7a8f1cb95e8fbd7ee538ac6824e75b0
  • Pointer size: 131 Bytes
  • Size of remote file: 543 kB
examples 2/ghibli_meme.jpg DELETED
Binary file (38.1 kB)
 
examples 2/leather.jpg DELETED

Git LFS Details

  • SHA256: efa1eab6d7fa83b2bb39631b194012cf01cca24356b624f32e0fd05346af3ec2
  • Pointer size: 131 Bytes
  • Size of remote file: 250 kB
examples 2/meme.jpg DELETED
Binary file (49.8 kB)
 
examples 2/no_cookie.png DELETED

Git LFS Details

  • SHA256: 4ee90a1e41774e2dae54ca436874341e750f2c7a6196b8360aee1952e98066f8
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
examples 2/poster.jpg DELETED
Binary file (65.4 kB)
 
examples 2/poster_orig.jpg DELETED

Git LFS Details

  • SHA256: 92a4178a56e7fefd7dfd418c675c1ab6b6b2e00e17b45a778a1100ab62f9bfba
  • Pointer size: 131 Bytes
  • Size of remote file: 458 kB
ghibli_meme.jpg DELETED
Binary file (38.1 kB)
 
layers.cpython-310.pyc DELETED
Binary file (19.1 kB)
 
layers.py DELETED
@@ -1,640 +0,0 @@
1
- # Modified from Flux
2
- #
3
- # Copyright 2024 Black Forest Labs
4
-
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
-
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
-
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
- # This source code is licensed under the license found in the
18
- # LICENSE file in the root directory of this source tree.
19
-
20
- import math # noqa: I001
21
- from dataclasses import dataclass
22
- from functools import partial
23
-
24
- import torch
25
- import torch.nn.functional as F
26
- from einops import rearrange
27
- # from liger_kernel.ops.rms_norm import LigerRMSNormFunction
28
- from torch import Tensor, nn
29
-
30
-
31
- try:
32
- import flash_attn
33
- from flash_attn.flash_attn_interface import (
34
- _flash_attn_forward,
35
- flash_attn_varlen_func,
36
- )
37
- except ImportError:
38
- flash_attn = None
39
- flash_attn_varlen_func = None
40
- _flash_attn_forward = None
41
-
42
-
43
- MEMORY_LAYOUT = {
44
- "flash": (
45
- lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
46
- lambda x: x,
47
- ),
48
- "torch": (
49
- lambda x: x.transpose(1, 2),
50
- lambda x: x.transpose(1, 2),
51
- ),
52
- "vanilla": (
53
- lambda x: x.transpose(1, 2),
54
- lambda x: x.transpose(1, 2),
55
- ),
56
- }
57
-
58
-
59
- def attention(
60
- q,
61
- k,
62
- v,
63
- mode="torch",
64
- drop_rate=0,
65
- attn_mask=None,
66
- causal=False,
67
- cu_seqlens_q=None,
68
- cu_seqlens_kv=None,
69
- max_seqlen_q=None,
70
- max_seqlen_kv=None,
71
- batch_size=1,
72
- ):
73
- """
74
- Perform QKV self attention.
75
-
76
- Args:
77
- q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
78
- k (torch.Tensor): Key tensor with shape [b, s1, a, d]
79
- v (torch.Tensor): Value tensor with shape [b, s1, a, d]
80
- mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
81
- drop_rate (float): Dropout rate in attention map. (default: 0)
82
- attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
83
- (default: None)
84
- causal (bool): Whether to use causal attention. (default: False)
85
- cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
86
- used to index into q.
87
- cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
88
- used to index into kv.
89
- max_seqlen_q (int): The maximum sequence length in the batch of q.
90
- max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
91
-
92
- Returns:
93
- torch.Tensor: Output tensor after self attention with shape [b, s, ad]
94
- """
95
- pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
96
- q = pre_attn_layout(q)
97
- k = pre_attn_layout(k)
98
- v = pre_attn_layout(v)
99
-
100
- if mode == "torch":
101
- if attn_mask is not None and attn_mask.dtype != torch.bool:
102
- attn_mask = attn_mask.to(q.dtype)
103
- x = F.scaled_dot_product_attention(
104
- q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
105
- )
106
- elif mode == "flash":
107
- assert flash_attn_varlen_func is not None
108
- x: torch.Tensor = flash_attn_varlen_func(
109
- q,
110
- k,
111
- v,
112
- cu_seqlens_q,
113
- cu_seqlens_kv,
114
- max_seqlen_q,
115
- max_seqlen_kv,
116
- ) # type: ignore
117
- # x with shape [(bxs), a, d]
118
- x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # type: ignore # reshape x to [b, s, a, d]
119
- elif mode == "vanilla":
120
- scale_factor = 1 / math.sqrt(q.size(-1))
121
-
122
- b, a, s, _ = q.shape
123
- s1 = k.size(2)
124
- attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
125
- if causal:
126
- # Only applied to self attention
127
- assert attn_mask is None, (
128
- "Causal mask and attn_mask cannot be used together"
129
- )
130
- temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
131
- diagonal=0
132
- )
133
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
134
- attn_bias.to(q.dtype)
135
-
136
- if attn_mask is not None:
137
- if attn_mask.dtype == torch.bool:
138
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
139
- else:
140
- attn_bias += attn_mask
141
-
142
- # TODO: Maybe force q and k to be float32 to avoid numerical overflow
143
- attn = (q @ k.transpose(-2, -1)) * scale_factor
144
- attn += attn_bias
145
- attn = attn.softmax(dim=-1)
146
- attn = torch.dropout(attn, p=drop_rate, train=True)
147
- x = attn @ v
148
- else:
149
- raise NotImplementedError(f"Unsupported attention mode: {mode}")
150
-
151
- x = post_attn_layout(x)
152
- b, s, a, d = x.shape
153
- out = x.reshape(b, s, -1)
154
- return out
155
-
156
-
157
- def apply_gate(x, gate=None, tanh=False):
158
- """Apply an optional gating tensor to x.
159
-
160
- Args:
161
- x (torch.Tensor): input tensor.
162
- gate (torch.Tensor, optional): gate tensor. Defaults to None.
163
- tanh (bool, optional): whether to use tanh function. Defaults to False.
164
-
165
- Returns:
166
- torch.Tensor: the output tensor after apply gate.
167
- """
168
- if gate is None:
169
- return x
170
- if tanh:
171
- return x * gate.unsqueeze(1).tanh()
172
- else:
173
- return x * gate.unsqueeze(1)
174
-
175
-
176
- class MLP(nn.Module):
177
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
178
-
179
- def __init__(
180
- self,
181
- in_channels,
182
- hidden_channels=None,
183
- out_features=None,
184
- act_layer=nn.GELU,
185
- norm_layer=None,
186
- bias=True,
187
- drop=0.0,
188
- use_conv=False,
189
- device=None,
190
- dtype=None,
191
- ):
192
- super().__init__()
193
- out_features = out_features or in_channels
194
- hidden_channels = hidden_channels or in_channels
195
- bias = (bias, bias)
196
- drop_probs = (drop, drop)
197
- linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
198
-
199
- self.fc1 = linear_layer(
200
- in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
201
- )
202
- self.act = act_layer()
203
- self.drop1 = nn.Dropout(drop_probs[0])
204
- self.norm = (
205
- norm_layer(hidden_channels, device=device, dtype=dtype)
206
- if norm_layer is not None
207
- else nn.Identity()
208
- )
209
- self.fc2 = linear_layer(
210
- hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
211
- )
212
- self.drop2 = nn.Dropout(drop_probs[1])
213
-
214
- def forward(self, x):
215
- x = self.fc1(x)
216
- x = self.act(x)
217
- x = self.drop1(x)
218
- x = self.norm(x)
219
- x = self.fc2(x)
220
- x = self.drop2(x)
221
- return x
222
-
223
-
224
- class TextProjection(nn.Module):
225
- """
226
- Projects text embeddings. Also handles dropout for classifier-free guidance.
227
-
228
- Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
229
- """
230
-
231
- def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
232
- factory_kwargs = {"dtype": dtype, "device": device}
233
- super().__init__()
234
- self.linear_1 = nn.Linear(
235
- in_features=in_channels,
236
- out_features=hidden_size,
237
- bias=True,
238
- **factory_kwargs,
239
- )
240
- self.act_1 = act_layer()
241
- self.linear_2 = nn.Linear(
242
- in_features=hidden_size,
243
- out_features=hidden_size,
244
- bias=True,
245
- **factory_kwargs,
246
- )
247
-
248
- def forward(self, caption):
249
- hidden_states = self.linear_1(caption)
250
- hidden_states = self.act_1(hidden_states)
251
- hidden_states = self.linear_2(hidden_states)
252
- return hidden_states
253
-
254
-
255
- class TimestepEmbedder(nn.Module):
256
- """
257
- Embeds scalar timesteps into vector representations.
258
- """
259
-
260
- def __init__(
261
- self,
262
- hidden_size,
263
- act_layer,
264
- frequency_embedding_size=256,
265
- max_period=10000,
266
- out_size=None,
267
- dtype=None,
268
- device=None,
269
- ):
270
- factory_kwargs = {"dtype": dtype, "device": device}
271
- super().__init__()
272
- self.frequency_embedding_size = frequency_embedding_size
273
- self.max_period = max_period
274
- if out_size is None:
275
- out_size = hidden_size
276
-
277
- self.mlp = nn.Sequential(
278
- nn.Linear(
279
- frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
280
- ),
281
- act_layer(),
282
- nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
283
- )
284
- nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
285
- nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
286
-
287
- @staticmethod
288
- def timestep_embedding(t, dim, max_period=10000):
289
- """
290
- Create sinusoidal timestep embeddings.
291
-
292
- Args:
293
- t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
294
- dim (int): the dimension of the output.
295
- max_period (int): controls the minimum frequency of the embeddings.
296
-
297
- Returns:
298
- embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
299
-
300
- .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
301
- """
302
- half = dim // 2
303
- freqs = torch.exp(
304
- -math.log(max_period)
305
- * torch.arange(start=0, end=half, dtype=torch.float32)
306
- / half
307
- ).to(device=t.device)
308
- args = t[:, None].float() * freqs[None]
309
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
310
- if dim % 2:
311
- embedding = torch.cat(
312
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
313
- )
314
- return embedding
315
-
316
- def forward(self, t):
317
- t_freq = self.timestep_embedding(
318
- t, self.frequency_embedding_size, self.max_period
319
- ).type(self.mlp[0].weight.dtype) # type: ignore
320
- t_emb = self.mlp(t_freq)
321
- return t_emb
322
-
323
-
324
- class EmbedND(nn.Module):
325
- def __init__(self, dim: int, theta: int, axes_dim: list[int]):
326
- super().__init__()
327
- self.dim = dim
328
- self.theta = theta
329
- self.axes_dim = axes_dim
330
-
331
- def forward(self, ids: Tensor) -> Tensor:
332
- n_axes = ids.shape[-1]
333
- emb = torch.cat(
334
- [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
335
- dim=-3,
336
- )
337
-
338
- return emb.unsqueeze(1)
339
-
340
-
341
- class MLPEmbedder(nn.Module):
342
- def __init__(self, in_dim: int, hidden_dim: int):
343
- super().__init__()
344
- self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
345
- self.silu = nn.SiLU()
346
- self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
347
-
348
- def forward(self, x: Tensor) -> Tensor:
349
- return self.out_layer(self.silu(self.in_layer(x)))
350
-
351
-
352
- def rope(pos, dim: int, theta: int):
353
- assert dim % 2 == 0
354
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
355
- omega = 1.0 / (theta**scale)
356
- out = torch.einsum("...n,d->...nd", pos, omega)
357
- out = torch.stack(
358
- [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
359
- )
360
- out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
361
- return out.float()
362
-
363
-
364
- def attention_after_rope(q, k, v, pe):
365
- q, k = apply_rope(q, k, pe)
366
-
367
- from .attention import attention
368
-
369
- x = attention(q, k, v, mode="torch")
370
- return x
371
-
372
-
373
- @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
374
- def apply_rope(xq, xk, freqs_cis):
375
- # swap the num_heads and seq_len dimensions back to the order the original function expects
376
- xq = xq.transpose(1, 2) # [batch, num_heads, seq_len, head_dim]
377
- xk = xk.transpose(1, 2)
378
-
379
- # split head_dim into complex components (real and imaginary parts)
380
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
381
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
382
-
383
- # apply the rotary position embedding (complex multiplication)
384
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
385
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
386
-
387
- # restore the tensor shape and transpose back to the target dimension order
388
- xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
389
- xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)
390
-
391
- return xq_out, xk_out
392
-
393
-
394
- @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
395
- def scale_add_residual(
396
- x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
397
- ) -> torch.Tensor:
398
- return x * scale + residual
399
-
400
-
401
- @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
402
- def layernorm_and_scale_shift(
403
- x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
404
- ) -> torch.Tensor:
405
- return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift
406
-
407
-
408
- class SelfAttention(nn.Module):
409
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
410
- super().__init__()
411
- self.num_heads = num_heads
412
- head_dim = dim // num_heads
413
-
414
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
415
- self.norm = QKNorm(head_dim)
416
- self.proj = nn.Linear(dim, dim)
417
-
418
- def forward(self, x: Tensor, pe: Tensor) -> Tensor:
419
- qkv = self.qkv(x)
420
- q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
421
- q, k = self.norm(q, k, v)
422
- x = attention_after_rope(q, k, v, pe=pe)
423
- x = self.proj(x)
424
- return x
425
-
426
-
427
- @dataclass
428
- class ModulationOut:
429
- shift: Tensor
430
- scale: Tensor
431
- gate: Tensor
432
-
433
-
434
- class RMSNorm(torch.nn.Module):
435
- def __init__(self, dim: int):
436
- super().__init__()
437
- self.scale = nn.Parameter(torch.ones(dim))
438
-
439
- # @staticmethod
440
- # def rms_norm_fast(x, weight, eps):
441
- # return LigerRMSNormFunction.apply(
442
- # x,
443
- # weight,
444
- # eps,
445
- # 0.0,
446
- # "gemma",
447
- # True,
448
- # )
449
-
450
- @staticmethod
451
- def rms_norm(x, weight, eps):
452
- x_dtype = x.dtype
453
- x = x.float()
454
- rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
455
- return (x * rrms).to(dtype=x_dtype) * weight
456
-
457
- def forward(self, x: Tensor):
458
- # return self.rms_norm_fast(x, self.scale, 1e-6)
459
- return self.rms_norm(x, self.scale, 1e-6)
460
-
461
-
462
- class QKNorm(torch.nn.Module):
463
- def __init__(self, dim: int):
464
- super().__init__()
465
- self.query_norm = RMSNorm(dim)
466
- self.key_norm = RMSNorm(dim)
467
-
468
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
469
- q = self.query_norm(q)
470
- k = self.key_norm(k)
471
- return q.to(v), k.to(v)
472
-
473
-
474
- class Modulation(nn.Module):
475
- def __init__(self, dim: int, double: bool):
476
- super().__init__()
477
- self.is_double = double
478
- self.multiplier = 6 if double else 3
479
- self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
480
-
481
- def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
482
- out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
483
- self.multiplier, dim=-1
484
- )
485
-
486
- return (
487
- ModulationOut(*out[:3]),
488
- ModulationOut(*out[3:]) if self.is_double else None,
489
- )
490
-
491
-
492
- class DoubleStreamBlock(nn.Module):
493
- def __init__(
494
- self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
495
- ):
496
- super().__init__()
497
-
498
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
499
- self.num_heads = num_heads
500
- self.hidden_size = hidden_size
501
- self.img_mod = Modulation(hidden_size, double=True)
502
- self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
503
- self.img_attn = SelfAttention(
504
- dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
505
- )
506
-
507
- self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
508
- self.img_mlp = nn.Sequential(
509
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
510
- nn.GELU(approximate="tanh"),
511
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
512
- )
513
-
514
- self.txt_mod = Modulation(hidden_size, double=True)
515
- self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
516
- self.txt_attn = SelfAttention(
517
- dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
518
- )
519
-
520
- self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
521
- self.txt_mlp = nn.Sequential(
522
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
523
- nn.GELU(approximate="tanh"),
524
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
525
- )
526
-
527
- def forward(
528
- self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
529
- ) -> tuple[Tensor, Tensor]:
530
- img_mod1, img_mod2 = self.img_mod(vec)
531
- txt_mod1, txt_mod2 = self.txt_mod(vec)
532
-
533
- # prepare image for attention
534
- img_modulated = self.img_norm1(img)
535
- img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
536
- img_qkv = self.img_attn.qkv(img_modulated)
537
- img_q, img_k, img_v = rearrange(
538
- img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
539
- )
540
- img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
541
-
542
- # prepare txt for attention
543
- txt_modulated = self.txt_norm1(txt)
544
- txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
545
- txt_qkv = self.txt_attn.qkv(txt_modulated)
546
- txt_q, txt_k, txt_v = rearrange(
547
- txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
548
- )
549
- txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
550
-
551
- # run actual attention
552
- q = torch.cat((txt_q, img_q), dim=1)
553
- k = torch.cat((txt_k, img_k), dim=1)
554
- v = torch.cat((txt_v, img_v), dim=1)
555
-
556
- attn = attention_after_rope(q, k, v, pe=pe)
557
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
558
-
559
- # calculate the img blocks
560
- img = img + img_mod1.gate * self.img_attn.proj(img_attn)
561
- img_mlp = self.img_mlp(
562
- (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
563
- )
564
- img = scale_add_residual(img_mlp, img_mod2.gate, img)
565
-
566
- # calculate the txt blocks
567
- txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
568
- txt_mlp = self.txt_mlp(
569
- (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
570
- )
571
- txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
572
- return img, txt
573
-
574
-
575
- class SingleStreamBlock(nn.Module):
576
- """
577
- A DiT block with parallel linear layers as described in
578
- https://arxiv.org/abs/2302.05442 and adapted modulation interface.
579
- """
580
-
581
- def __init__(
582
- self,
583
- hidden_size: int,
584
- num_heads: int,
585
- mlp_ratio: float = 4.0,
586
- qk_scale: float | None = None,
587
- ):
588
- super().__init__()
589
- self.hidden_dim = hidden_size
590
- self.num_heads = num_heads
591
- head_dim = hidden_size // num_heads
592
- self.scale = qk_scale or head_dim**-0.5
593
-
594
- self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
595
- # qkv and mlp_in
596
- self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
597
- # proj and mlp_out
598
- self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
599
-
600
- self.norm = QKNorm(head_dim)
601
-
602
- self.hidden_size = hidden_size
603
- self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
604
-
605
- self.mlp_act = nn.GELU(approximate="tanh")
606
- self.modulation = Modulation(hidden_size, double=False)
607
-
608
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
609
- mod, _ = self.modulation(vec)
610
- x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
611
- qkv, mlp = torch.split(
612
- self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
613
- )
614
-
615
- q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
616
- q, k = self.norm(q, k, v)
617
-
618
- # compute attention
619
- attn = attention_after_rope(q, k, v, pe=pe)
620
- # compute activation in mlp stream, cat again and run second linear layer
621
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
622
- return scale_add_residual(output, mod.gate, x)
623
-
624
-
625
- class LastLayer(nn.Module):
626
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
627
- super().__init__()
628
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
629
- self.linear = nn.Linear(
630
- hidden_size, patch_size * patch_size * out_channels, bias=True
631
- )
632
- self.adaLN_modulation = nn.Sequential(
633
- nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
634
- )
635
-
636
- def forward(self, x: Tensor, vec: Tensor) -> Tensor:
637
- shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
638
- x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
639
- x = self.linear(x)
640
- return x
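
For orientation, a minimal sketch of how the rope() helper deleted above is used: EmbedND calls it once per positional axis and concatenates the results, so axes_dim must sum to the per-head dimension. The shapes and theta value below are illustrative assumptions, not configuration taken from this repo.

    import torch

    pos = torch.arange(4, dtype=torch.float64)[None, :]   # (1, 4): one positional axis, four token positions
    rot = rope(pos, dim=16, theta=10_000)                  # (1, 4, 8, 2, 2): a 2x2 rotation matrix per frequency
    # EmbedND concatenates one such block per axis along dim=-3 and unsqueezes a head dimension,
    # which is what attention_after_rope()/apply_rope() later consume as the positional encoding `pe`.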
 
leather.jpg DELETED

Git LFS Details

  • SHA256: efa1eab6d7fa83b2bb39631b194012cf01cca24356b624f32e0fd05346af3ec2
  • Pointer size: 131 Bytes
  • Size of remote file: 250 kB
meme.jpg DELETED
Binary file (49.8 kB)
 
model_edit.cpython-310.pyc DELETED
Binary file (4.21 kB)
 
model_edit.py DELETED
@@ -1,143 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
-
4
- import numpy as np
5
- import torch
6
- from torch import Tensor, nn
7
-
8
- from .connector_edit import Qwen2Connector
9
- from .layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock
10
-
11
-
12
- @dataclass
13
- class Step1XParams:
14
- in_channels: int
15
- out_channels: int
16
- vec_in_dim: int
17
- context_in_dim: int
18
- hidden_size: int
19
- mlp_ratio: float
20
- num_heads: int
21
- depth: int
22
- depth_single_blocks: int
23
- axes_dim: list[int]
24
- theta: int
25
- qkv_bias: bool
26
-
27
-
28
- class Step1XEdit(nn.Module):
29
- """
30
- Transformer model for flow matching on sequences.
31
- """
32
-
33
- def __init__(self, params: Step1XParams):
34
- super().__init__()
35
-
36
- self.params = params
37
- self.in_channels = params.in_channels
38
- self.out_channels = params.out_channels
39
- if params.hidden_size % params.num_heads != 0:
40
- raise ValueError(
41
- f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
42
- )
43
- pe_dim = params.hidden_size // params.num_heads
44
- if sum(params.axes_dim) != pe_dim:
45
- raise ValueError(
46
- f"Got {params.axes_dim} but expected positional dim {pe_dim}"
47
- )
48
- self.hidden_size = params.hidden_size
49
- self.num_heads = params.num_heads
50
- self.pe_embedder = EmbedND(
51
- dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
52
- )
53
- self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
54
- self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
55
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
56
- self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
57
-
58
- self.double_blocks = nn.ModuleList(
59
- [
60
- DoubleStreamBlock(
61
- self.hidden_size,
62
- self.num_heads,
63
- mlp_ratio=params.mlp_ratio,
64
- qkv_bias=params.qkv_bias,
65
- )
66
- for _ in range(params.depth)
67
- ]
68
- )
69
-
70
- self.single_blocks = nn.ModuleList(
71
- [
72
- SingleStreamBlock(
73
- self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
74
- )
75
- for _ in range(params.depth_single_blocks)
76
- ]
77
- )
78
-
79
- self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
80
-
81
- self.connector = Qwen2Connector()
82
-
83
- @staticmethod
84
- def timestep_embedding(
85
- t: Tensor, dim, max_period=10000, time_factor: float = 1000.0
86
- ):
87
- """
88
- Create sinusoidal timestep embeddings.
89
- :param t: a 1-D Tensor of N indices, one per batch element.
90
- These may be fractional.
91
- :param dim: the dimension of the output.
92
- :param max_period: controls the minimum frequency of the embeddings.
93
- :return: an (N, D) Tensor of positional embeddings.
94
- """
95
- t = time_factor * t
96
- half = dim // 2
97
- freqs = torch.exp(
98
- -math.log(max_period)
99
- * torch.arange(start=0, end=half, dtype=torch.float32)
100
- / half
101
- ).to(t.device)
102
-
103
- args = t[:, None].float() * freqs[None]
104
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
105
- if dim % 2:
106
- embedding = torch.cat(
107
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
108
- )
109
- if torch.is_floating_point(t):
110
- embedding = embedding.to(t)
111
- return embedding
112
-
113
- def forward(
114
- self,
115
- img: Tensor,
116
- img_ids: Tensor,
117
- txt: Tensor,
118
- txt_ids: Tensor,
119
- timesteps: Tensor,
120
- y: Tensor,
121
- ) -> Tensor:
122
- if img.ndim != 3 or txt.ndim != 3:
123
- raise ValueError("Input img and txt tensors must have 3 dimensions.")
124
-
125
- img = self.img_in(img)
126
- vec = self.time_in(self.timestep_embedding(timesteps, 256))
127
-
128
- vec = vec + self.vector_in(y)
129
- txt = self.txt_in(txt)
130
-
131
- ids = torch.cat((txt_ids, img_ids), dim=1)
132
- pe = self.pe_embedder(ids)
133
-
134
- for block in self.double_blocks:
135
- img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
136
-
137
- img = torch.cat((txt, img), 1)
138
- for block in self.single_blocks:
139
- img = block(img, vec=vec, pe=pe)
140
- img = img[:, txt.shape[1] :, ...]
141
-
142
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
143
- return img
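
As a quick reference for the two constructor checks above, a hedged sketch of a Step1XParams instance that satisfies them; every number here is illustrative and not the released model's configuration.

    params = Step1XParams(
        in_channels=64, out_channels=64, vec_in_dim=768, context_in_dim=4096,
        hidden_size=1024, mlp_ratio=4.0, num_heads=16,
        depth=2, depth_single_blocks=4,
        axes_dim=[16, 24, 24],   # must sum to hidden_size // num_heads == 64
        theta=10_000, qkv_bias=True,
    )
    model = Step1XEdit(params)   # raises ValueError if either shape constraint is violated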
 
celeb_meme.jpg → modules.zip RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fb2ccab4218dba753781d65e8f5933f8ab7613543b59a7b4512a6654fe55a4f
3
- size 266588
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c432d89999f0ae531c09c6ccf1d4a69bf5c2bb878f23411fafdf64b7370c8afe
3
+ size 45293
modules/__init__.py DELETED
File without changes
modules/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (128 Bytes)
 
modules/__pycache__/attention.cpython-310.pyc DELETED
Binary file (3.13 kB)
 
modules/__pycache__/autoencoder.cpython-310.pyc DELETED
Binary file (8.78 kB)
 
modules/__pycache__/conditioner.cpython-310.pyc DELETED
Binary file (4.94 kB)
 
modules/__pycache__/connector_edit.cpython-310.pyc DELETED
Binary file (11.8 kB)
 
modules/__pycache__/layers.cpython-310.pyc DELETED
Binary file (19.1 kB)
 
modules/__pycache__/model_edit.cpython-310.pyc DELETED
Binary file (4.21 kB)
 
modules/attention.py DELETED
@@ -1,133 +0,0 @@
1
- import math
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
-
7
- try:
8
- import flash_attn
9
- from flash_attn.flash_attn_interface import (
10
- _flash_attn_forward,
11
- flash_attn_func,
12
- flash_attn_varlen_func,
13
- )
14
- except ImportError:
15
- flash_attn = None
16
- flash_attn_varlen_func = None
17
- _flash_attn_forward = None
18
- flash_attn_func = None
19
-
20
- MEMORY_LAYOUT = {
21
- # flash mode:
22
- # pre-processing: input is [batch_size, seq_len, num_heads, head_dim]
23
- # post-processing: keep the shape unchanged
24
- "flash": (
25
- lambda x: x, # keep shape
26
- lambda x: x, # keep shape
27
- ),
28
- # torch/vanilla mode:
29
- # pre-processing: swap the sequence and attention-head dimensions [B,S,A,D] -> [B,A,S,D]
30
- # post-processing: swap back to the original order [B,A,S,D] -> [B,S,A,D]
31
- "torch": (
32
- lambda x: x.transpose(1, 2), # (B,S,A,D) -> (B,A,S,D)
33
- lambda x: x.transpose(1, 2), # (B,A,S,D) -> (B,S,A,D)
34
- ),
35
- "vanilla": (
36
- lambda x: x.transpose(1, 2),
37
- lambda x: x.transpose(1, 2),
38
- ),
39
- }
40
-
41
-
42
- def attention(
43
- q,
44
- k,
45
- v,
46
- mode="torch",
47
- drop_rate=0,
48
- attn_mask=None,
49
- causal=False,
50
- ):
51
- """
52
- Perform QKV self-attention.
53
-
54
- Args:
55
- q (torch.Tensor): query tensor, shape [batch_size, seq_len, num_heads, head_dim]
56
- k (torch.Tensor): key tensor, shape [batch_size, seq_len_kv, num_heads, head_dim]
57
- v (torch.Tensor): value tensor, shape [batch_size, seq_len_kv, num_heads, head_dim]
58
- mode (str): attention mode, one of 'flash', 'torch', 'vanilla'
59
- drop_rate (float): dropout probability applied to the attention map
60
- attn_mask (torch.Tensor): attention mask; its shape depends on the mode
61
- causal (bool): whether to use causal attention (attend only to preceding positions)
62
-
63
- Returns:
64
- torch.Tensor: attention output, shape [batch_size, seq_len, num_heads * head_dim]
65
- """
66
- # get the pre- and post-processing functions for this mode
67
- pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
68
-
69
- # apply the pre-processing transform
70
- q = pre_attn_layout(q) # shape depends on the mode
71
- k = pre_attn_layout(k)
72
- v = pre_attn_layout(v)
73
-
74
- if mode == "torch":
75
- # use PyTorch's native scaled_dot_product_attention
76
- if attn_mask is not None and attn_mask.dtype != torch.bool:
77
- attn_mask = attn_mask.to(q.dtype)
78
- x = F.scaled_dot_product_attention(
79
- q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
80
- )
81
- elif mode == "flash":
82
- assert flash_attn_func is not None, "flash_attn_func is not available"
83
- assert attn_mask is None, "attn_mask is not supported in flash mode"
84
- x: torch.Tensor = flash_attn_func(
85
- q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
86
- ) # type: ignore
87
- elif mode == "vanilla":
88
- # manual (vanilla) attention implementation
89
- scale_factor = 1 / math.sqrt(q.size(-1)) # scaling factor 1/sqrt(d_k)
90
-
91
- b, a, s, _ = q.shape # unpack shape parameters
92
- s1 = k.size(2) # key/value sequence length
93
-
94
- # initialize the attention bias
95
- attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
96
-
97
- # handle the causal mask
98
- if causal:
99
- assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
100
- # generate the lower-triangular causal mask
101
- temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
102
- diagonal=0
103
- )
104
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
105
- attn_bias = attn_bias.to(q.dtype)
106
-
107
- # handle a custom attention mask
108
- if attn_mask is not None:
109
- if attn_mask.dtype == torch.bool:
110
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
111
- else:
112
- attn_bias += attn_mask # allows ALiBi-like positional biases
113
-
114
- # compute the attention matrix
115
- attn = (q @ k.transpose(-2, -1)) * scale_factor # [B,A,S,S1]
116
- attn += attn_bias
117
-
118
- # softmax and dropout
119
- attn = attn.softmax(dim=-1)
120
- attn = torch.dropout(attn, p=drop_rate, train=True)
121
-
122
- # compute the output
123
- x = attn @ v # [B,A,S,D]
124
- else:
125
- raise NotImplementedError(f"Unsupported attention mode: {mode}")
126
-
127
- # apply the post-processing transform
128
- x = post_attn_layout(x) # restore the original dimension order
129
-
130
- # merge the attention-head dimensions
131
- b, s, a, d = x.shape
132
- out = x.reshape(b, s, -1) # [B,S,A*D]
133
- return out
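
A minimal usage sketch of the attention() helper above in its default "torch" mode; the tensor sizes are arbitrary assumptions chosen only to show the expected [batch, seq, heads, head_dim] layout.

    import torch

    B, S, H, D = 2, 128, 8, 64
    q = torch.randn(B, S, H, D)
    k = torch.randn(B, S, H, D)
    v = torch.randn(B, S, H, D)
    out = attention(q, k, v, mode="torch")   # (B, S, H * D) == (2, 128, 512)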
 
modules/autoencoder.py DELETED
@@ -1,326 +0,0 @@
1
- # Modified from Flux
2
- #
3
- # Copyright 2024 Black Forest Labs
4
-
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
-
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
-
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
- # This source code is licensed under the license found in the
18
- # LICENSE file in the root directory of this source tree.
19
- import torch
20
- from einops import rearrange
21
- from torch import Tensor, nn
22
-
23
-
24
- def swish(x: Tensor) -> Tensor:
25
- return x * torch.sigmoid(x)
26
-
27
-
28
- class AttnBlock(nn.Module):
29
- def __init__(self, in_channels: int):
30
- super().__init__()
31
- self.in_channels = in_channels
32
-
33
- self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34
-
35
- self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
- self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
37
- self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
38
- self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
39
-
40
- def attention(self, h_: Tensor) -> Tensor:
41
- h_ = self.norm(h_)
42
- q = self.q(h_)
43
- k = self.k(h_)
44
- v = self.v(h_)
45
-
46
- b, c, h, w = q.shape
47
- q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
48
- k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
49
- v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
50
- h_ = nn.functional.scaled_dot_product_attention(q, k, v)
51
-
52
- return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
53
-
54
- def forward(self, x: Tensor) -> Tensor:
55
- return x + self.proj_out(self.attention(x))
56
-
57
-
58
- class ResnetBlock(nn.Module):
59
- def __init__(self, in_channels: int, out_channels: int):
60
- super().__init__()
61
- self.in_channels = in_channels
62
- out_channels = in_channels if out_channels is None else out_channels
63
- self.out_channels = out_channels
64
-
65
- self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
66
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
67
- self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
68
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
69
- if self.in_channels != self.out_channels:
70
- self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
71
-
72
- def forward(self, x):
73
- h = x
74
- h = self.norm1(h)
75
- h = swish(h)
76
- h = self.conv1(h)
77
-
78
- h = self.norm2(h)
79
- h = swish(h)
80
- h = self.conv2(h)
81
-
82
- if self.in_channels != self.out_channels:
83
- x = self.nin_shortcut(x)
84
-
85
- return x + h
86
-
87
-
88
- class Downsample(nn.Module):
89
- def __init__(self, in_channels: int):
90
- super().__init__()
91
- # no asymmetric padding in torch conv, must do it ourselves
92
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
93
-
94
- def forward(self, x: Tensor):
95
- pad = (0, 1, 0, 1)
96
- x = nn.functional.pad(x, pad, mode="constant", value=0)
97
- x = self.conv(x)
98
- return x
99
-
100
-
101
- class Upsample(nn.Module):
102
- def __init__(self, in_channels: int):
103
- super().__init__()
104
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
105
-
106
- def forward(self, x: Tensor):
107
- x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
108
- x = self.conv(x)
109
- return x
110
-
111
-
112
- class Encoder(nn.Module):
113
- def __init__(
114
- self,
115
- resolution: int,
116
- in_channels: int,
117
- ch: int,
118
- ch_mult: list[int],
119
- num_res_blocks: int,
120
- z_channels: int,
121
- ):
122
- super().__init__()
123
- self.ch = ch
124
- self.num_resolutions = len(ch_mult)
125
- self.num_res_blocks = num_res_blocks
126
- self.resolution = resolution
127
- self.in_channels = in_channels
128
- # downsampling
129
- self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
130
-
131
- curr_res = resolution
132
- in_ch_mult = (1, *tuple(ch_mult))
133
- self.in_ch_mult = in_ch_mult
134
- self.down = nn.ModuleList()
135
- block_in = self.ch
136
- for i_level in range(self.num_resolutions):
137
- block = nn.ModuleList()
138
- attn = nn.ModuleList()
139
- block_in = ch * in_ch_mult[i_level]
140
- block_out = ch * ch_mult[i_level]
141
- for _ in range(self.num_res_blocks):
142
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
143
- block_in = block_out
144
- down = nn.Module()
145
- down.block = block
146
- down.attn = attn
147
- if i_level != self.num_resolutions - 1:
148
- down.downsample = Downsample(block_in)
149
- curr_res = curr_res // 2
150
- self.down.append(down)
151
-
152
- # middle
153
- self.mid = nn.Module()
154
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
155
- self.mid.attn_1 = AttnBlock(block_in)
156
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
157
-
158
- # end
159
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
160
- self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
161
-
162
- def forward(self, x: Tensor) -> Tensor:
163
- # downsampling
164
- hs = [self.conv_in(x)]
165
- for i_level in range(self.num_resolutions):
166
- for i_block in range(self.num_res_blocks):
167
- h = self.down[i_level].block[i_block](hs[-1])
168
- if len(self.down[i_level].attn) > 0:
169
- h = self.down[i_level].attn[i_block](h)
170
- hs.append(h)
171
- if i_level != self.num_resolutions - 1:
172
- hs.append(self.down[i_level].downsample(hs[-1]))
173
-
174
- # middle
175
- h = hs[-1]
176
- h = self.mid.block_1(h)
177
- h = self.mid.attn_1(h)
178
- h = self.mid.block_2(h)
179
- # end
180
- h = self.norm_out(h)
181
- h = swish(h)
182
- h = self.conv_out(h)
183
- return h
184
-
185
-
186
- class Decoder(nn.Module):
187
- def __init__(
188
- self,
189
- ch: int,
190
- out_ch: int,
191
- ch_mult: list[int],
192
- num_res_blocks: int,
193
- in_channels: int,
194
- resolution: int,
195
- z_channels: int,
196
- ):
197
- super().__init__()
198
- self.ch = ch
199
- self.num_resolutions = len(ch_mult)
200
- self.num_res_blocks = num_res_blocks
201
- self.resolution = resolution
202
- self.in_channels = in_channels
203
- self.ffactor = 2 ** (self.num_resolutions - 1)
204
-
205
- # compute in_ch_mult, block_in and curr_res at lowest res
206
- block_in = ch * ch_mult[self.num_resolutions - 1]
207
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
208
- self.z_shape = (1, z_channels, curr_res, curr_res)
209
-
210
- # z to block_in
211
- self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
212
-
213
- # middle
214
- self.mid = nn.Module()
215
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
216
- self.mid.attn_1 = AttnBlock(block_in)
217
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
218
-
219
- # upsampling
220
- self.up = nn.ModuleList()
221
- for i_level in reversed(range(self.num_resolutions)):
222
- block = nn.ModuleList()
223
- attn = nn.ModuleList()
224
- block_out = ch * ch_mult[i_level]
225
- for _ in range(self.num_res_blocks + 1):
226
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
227
- block_in = block_out
228
- up = nn.Module()
229
- up.block = block
230
- up.attn = attn
231
- if i_level != 0:
232
- up.upsample = Upsample(block_in)
233
- curr_res = curr_res * 2
234
- self.up.insert(0, up) # prepend to get consistent order
235
-
236
- # end
237
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
238
- self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
239
-
240
- def forward(self, z: Tensor) -> Tensor:
241
- # z to block_in
242
- h = self.conv_in(z)
243
-
244
- # middle
245
- h = self.mid.block_1(h)
246
- h = self.mid.attn_1(h)
247
- h = self.mid.block_2(h)
248
-
249
- # upsampling
250
- for i_level in reversed(range(self.num_resolutions)):
251
- for i_block in range(self.num_res_blocks + 1):
252
- h = self.up[i_level].block[i_block](h)
253
- if len(self.up[i_level].attn) > 0:
254
- h = self.up[i_level].attn[i_block](h)
255
- if i_level != 0:
256
- h = self.up[i_level].upsample(h)
257
-
258
- # end
259
- h = self.norm_out(h)
260
- h = swish(h)
261
- h = self.conv_out(h)
262
- return h
263
-
264
-
265
- class DiagonalGaussian(nn.Module):
266
- def __init__(self, sample: bool = True, chunk_dim: int = 1):
267
- super().__init__()
268
- self.sample = sample
269
- self.chunk_dim = chunk_dim
270
-
271
- def forward(self, z: Tensor) -> Tensor:
272
- mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
273
- if self.sample:
274
- std = torch.exp(0.5 * logvar)
275
- return mean + std * torch.randn_like(mean)
276
- else:
277
- return mean
278
-
279
-
280
- class AutoEncoder(nn.Module):
281
- def __init__(
282
- self,
283
- resolution: int,
284
- in_channels: int,
285
- ch: int,
286
- out_ch: int,
287
- ch_mult: list[int],
288
- num_res_blocks: int,
289
- z_channels: int,
290
- scale_factor: float,
291
- shift_factor: float,
292
- ):
293
- super().__init__()
294
- self.encoder = Encoder(
295
- resolution=resolution,
296
- in_channels=in_channels,
297
- ch=ch,
298
- ch_mult=ch_mult,
299
- num_res_blocks=num_res_blocks,
300
- z_channels=z_channels,
301
- )
302
- self.decoder = Decoder(
303
- resolution=resolution,
304
- in_channels=in_channels,
305
- ch=ch,
306
- out_ch=out_ch,
307
- ch_mult=ch_mult,
308
- num_res_blocks=num_res_blocks,
309
- z_channels=z_channels,
310
- )
311
- self.reg = DiagonalGaussian()
312
-
313
- self.scale_factor = scale_factor
314
- self.shift_factor = shift_factor
315
-
316
- def encode(self, x: Tensor) -> Tensor:
317
- z = self.reg(self.encoder(x))
318
- z = self.scale_factor * (z - self.shift_factor)
319
- return z
320
-
321
- def decode(self, z: Tensor) -> Tensor:
322
- z = z / self.scale_factor + self.shift_factor
323
- return self.decoder(z)
324
-
325
- def forward(self, x: Tensor) -> Tensor:
326
- return self.decode(self.encode(x))
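
For reference, a rough round-trip sketch of the AutoEncoder above; the resolution, channel multipliers, and the neutral scale/shift factors are assumptions for illustration, not the values shipped with this Space.

    import torch

    ae = AutoEncoder(
        resolution=256, in_channels=3, ch=128, out_ch=3,
        ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
        scale_factor=1.0, shift_factor=0.0,
    )
    x = torch.randn(1, 3, 256, 256)
    z = ae.encode(x)       # (1, 16, 32, 32): spatial size divided by 2 ** (len(ch_mult) - 1)
    x_rec = ae.decode(z)   # (1, 3, 256, 256)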
 
modules/conditioner.py DELETED
@@ -1,216 +0,0 @@
1
- import torch
2
- from qwen_vl_utils import process_vision_info
3
- from transformers import (
4
- AutoProcessor,
5
- Qwen2VLForConditionalGeneration,
6
- Qwen2_5_VLForConditionalGeneration,
7
- )
8
- from torchvision.transforms import ToPILImage
9
-
10
- to_pil = ToPILImage()
11
-
12
- Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
13
- - If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
14
- - If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
15
- Here are examples of how to transform or refine prompts:
16
- - User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
17
- - User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
18
- Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
19
- User Prompt:'''
20
-
21
-
22
- def split_string(s):
23
- # replace Chinese quotation marks with English quotes
24
- s = s.replace("“", '"').replace("”", '"') # use english quotes
25
- result = []
26
- # flag: whether we are currently inside quotes
27
- in_quotes = False
28
- temp = ""
29
-
30
- # iterate over every character in the string and its index
31
- for idx, char in enumerate(s):
32
- # if the character is a quote and the index is greater than 155
33
- if char == '"' and idx > 155:
34
- # append the quote to the temporary string
35
- temp += char
36
- # if we are not inside quotes
37
- if not in_quotes:
38
- # append the temporary string to the result list
39
- result.append(temp)
40
- # clear the temporary string
41
- temp = ""
42
-
43
- # toggle the in-quotes state
44
- in_quotes = not in_quotes
45
- continue
46
- # if we are inside quotes
47
- if in_quotes:
48
- # if the character is whitespace
49
- if char.isspace():
50
- pass # have space token
51
-
52
- # wrap the character in Chinese quotes and append it to the result list
53
- result.append("“" + char + "”")
54
- else:
55
- # append the character to the temporary string
56
- temp += char
57
-
58
- # if the temporary string is not empty
59
- if temp:
60
- # append the temporary string to the result list
61
- result.append(temp)
62
-
63
- return result
64
-
65
-
66
- class Qwen25VL_7b_Embedder(torch.nn.Module):
67
- def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
68
- super(Qwen25VL_7b_Embedder, self).__init__()
69
- self.max_length = max_length
70
- self.dtype = dtype
71
- self.device = device
72
-
73
- self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
74
- model_path,
75
- torch_dtype=dtype,
76
- attn_implementation="eager",
77
- ).to(torch.cuda.current_device())
78
-
79
- self.model.requires_grad_(False)
80
- self.processor = AutoProcessor.from_pretrained(
81
- model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
82
- )
83
-
84
- self.prefix = Qwen25VL_7b_PREFIX
85
-
86
- def forward(self, caption, ref_images):
87
- text_list = caption
88
- embs = torch.zeros(
89
- len(text_list),
90
- self.max_length,
91
- self.model.config.hidden_size,
92
- dtype=torch.bfloat16,
93
- device=torch.cuda.current_device(),
94
- )
95
- hidden_states = torch.zeros(
96
- len(text_list),
97
- self.max_length,
98
- self.model.config.hidden_size,
99
- dtype=torch.bfloat16,
100
- device=torch.cuda.current_device(),
101
- )
102
- masks = torch.zeros(
103
- len(text_list),
104
- self.max_length,
105
- dtype=torch.long,
106
- device=torch.cuda.current_device(),
107
- )
108
- input_ids_list = []
109
- attention_mask_list = []
110
- emb_list = []
111
-
112
- def split_string(s):
113
- s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
114
- result = []
115
- in_quotes = False
116
- temp = ""
117
-
118
- for idx,char in enumerate(s):
119
- if char == '"' and idx>155:
120
- temp += char
121
- if not in_quotes:
122
- result.append(temp)
123
- temp = ""
124
-
125
- in_quotes = not in_quotes
126
- continue
127
- if in_quotes:
128
- if char.isspace():
129
- pass # have space token
130
-
131
- result.append("“" + char + "”")
132
- else:
133
- temp += char
134
-
135
- if temp:
136
- result.append(temp)
137
-
138
- return result
139
-
140
- for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
141
-
142
- messages = [{"role": "user", "content": []}]
143
-
144
- messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
145
-
146
- messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})
147
-
148
- # then add the text
149
- messages[0]["content"].append({"type": "text", "text": f"{txt}"})
150
-
151
- # Preparation for inference
152
- text = self.processor.apply_chat_template(
153
- messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
154
- )
155
-
156
- image_inputs, video_inputs = process_vision_info(messages)
157
-
158
- inputs = self.processor(
159
- text=[text],
160
- images=image_inputs,
161
- padding=True,
162
- return_tensors="pt",
163
- )
164
-
165
- old_inputs_ids = inputs.input_ids
166
- text_split_list = split_string(text)
167
-
168
- token_list = []
169
- for text_each in text_split_list:
170
- txt_inputs = self.processor(
171
- text=text_each,
172
- images=None,
173
- videos=None,
174
- padding=True,
175
- return_tensors="pt",
176
- )
177
- token_each = txt_inputs.input_ids
178
- if token_each[0][0] == 2073 and token_each[0][-1] == 854:
179
- token_each = token_each[:, 1:-1]
180
- token_list.append(token_each)
181
- else:
182
- token_list.append(token_each)
183
-
184
- new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
185
-
186
- new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
187
-
188
- idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
189
- idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
190
- inputs.input_ids = (
191
- torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
192
- .unsqueeze(0)
193
- .to("cuda")
194
- )
195
- inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
196
- outputs = self.model(
197
- input_ids=inputs.input_ids,
198
- attention_mask=inputs.attention_mask,
199
- pixel_values=inputs.pixel_values.to("cuda"),
200
- image_grid_thw=inputs.image_grid_thw.to("cuda"),
201
- output_hidden_states=True,
202
- )
203
-
204
- emb = outputs["hidden_states"][-1]
205
-
206
- embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
207
- : self.max_length
208
- ]
209
-
210
- masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
211
- (min(self.max_length, emb.shape[1] - 217)),
212
- dtype=torch.long,
213
- device=torch.cuda.current_device(),
214
- )
215
-
216
- return embs, masks
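
A hedged interface sketch for the embedder above; the model path, prompt, and image shape are placeholders, and the call assumes a CUDA device since the module pins its buffers to torch.cuda.current_device().

    # embedder = Qwen25VL_7b_Embedder(qwen2vl_model_path, max_length=640)   # path supplied by the caller
    # ref = torch.rand(1, 3, 512, 512)            # batch of reference images in [0, 1], CHW layout
    # embs, masks = embedder(["replace the sky with a sunset"], ref)
    # embs:  (1, 640, hidden_size) prompt embeddings taken from the last hidden states
    # masks: (1, 640) long tensor marking which of the max_length positions are valid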
 
modules/connector_edit.py DELETED
@@ -1,486 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- import torch.nn
5
- from einops import rearrange
6
- from torch import nn
7
-
8
- from .layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention
9
-
10
-
11
- class RMSNorm(nn.Module):
12
- def __init__(
13
- self,
14
- dim: int,
15
- elementwise_affine=True,
16
- eps: float = 1e-6,
17
- device=None,
18
- dtype=None,
19
- ):
20
- """
21
- Initialize the RMSNorm normalization layer.
22
-
23
- Args:
24
- dim (int): The dimension of the input tensor.
25
- eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
26
-
27
- Attributes:
28
- eps (float): A small value added to the denominator for numerical stability.
29
- weight (nn.Parameter): Learnable scaling parameter.
30
-
31
- """
32
- factory_kwargs = {"device": device, "dtype": dtype}
33
- super().__init__()
34
- self.eps = eps
35
- if elementwise_affine:
36
- self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
37
-
38
- def _norm(self, x):
39
- """
40
- Apply the RMSNorm normalization to the input tensor.
41
-
42
- Args:
43
- x (torch.Tensor): The input tensor.
44
-
45
- Returns:
46
- torch.Tensor: The normalized tensor.
47
-
48
- """
49
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
50
-
51
- def forward(self, x):
52
- """
53
- Forward pass through the RMSNorm layer.
54
-
55
- Args:
56
- x (torch.Tensor): The input tensor.
57
-
58
- Returns:
59
- torch.Tensor: The output tensor after applying RMSNorm.
60
-
61
- """
62
- output = self._norm(x.float()).type_as(x)
63
- if hasattr(self, "weight"):
64
- output = output * self.weight
65
- return output
66
-
67
-
68
- def get_norm_layer(norm_layer):
69
- """
70
- Get the normalization layer.
71
-
72
- Args:
73
- norm_layer (str): The type of normalization layer.
74
-
75
- Returns:
76
- norm_layer (nn.Module): The normalization layer.
77
- """
78
- if norm_layer == "layer":
79
- return nn.LayerNorm
80
- elif norm_layer == "rms":
81
- return RMSNorm
82
- else:
83
- raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
84
-
85
-
86
- def get_activation_layer(act_type):
87
- """get activation layer
88
-
89
- Args:
90
- act_type (str): the activation type
91
-
92
- Returns:
93
- torch.nn.functional: the activation layer
94
- """
95
- if act_type == "gelu":
96
- return lambda: nn.GELU()
97
- elif act_type == "gelu_tanh":
98
- return lambda: nn.GELU(approximate="tanh")
99
- elif act_type == "relu":
100
- return nn.ReLU
101
- elif act_type == "silu":
102
- return nn.SiLU
103
- else:
104
- raise ValueError(f"Unknown activation type: {act_type}")
105
-
106
- class IndividualTokenRefinerBlock(torch.nn.Module):
107
- def __init__(
108
- self,
109
- hidden_size,
110
- heads_num,
111
- mlp_width_ratio: float = 4.0,
112
- mlp_drop_rate: float = 0.0,
113
- act_type: str = "silu",
114
- qk_norm: bool = False,
115
- qk_norm_type: str = "layer",
116
- qkv_bias: bool = True,
117
- need_CA: bool = False,
118
- dtype: Optional[torch.dtype] = None,
119
- device: Optional[torch.device] = None,
120
- ):
121
- factory_kwargs = {"device": device, "dtype": dtype}
122
- super().__init__()
123
- self.need_CA = need_CA
124
- self.heads_num = heads_num
125
- head_dim = hidden_size // heads_num
126
- mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
127
-
128
- self.norm1 = nn.LayerNorm(
129
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
130
- )
131
- self.self_attn_qkv = nn.Linear(
132
- hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
133
- )
134
- qk_norm_layer = get_norm_layer(qk_norm_type)
135
- self.self_attn_q_norm = (
136
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
137
- if qk_norm
138
- else nn.Identity()
139
- )
140
- self.self_attn_k_norm = (
141
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
142
- if qk_norm
143
- else nn.Identity()
144
- )
145
- self.self_attn_proj = nn.Linear(
146
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
147
- )
148
-
149
- self.norm2 = nn.LayerNorm(
150
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
151
- )
152
- act_layer = get_activation_layer(act_type)
153
- self.mlp = MLP(
154
- in_channels=hidden_size,
155
- hidden_channels=mlp_hidden_dim,
156
- act_layer=act_layer,
157
- drop=mlp_drop_rate,
158
- **factory_kwargs,
159
- )
160
-
161
- self.adaLN_modulation = nn.Sequential(
162
- act_layer(),
163
- nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
164
- )
165
-
166
- if self.need_CA:
167
- self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
168
- heads_num=heads_num,
169
- mlp_width_ratio=mlp_width_ratio,
170
- mlp_drop_rate=mlp_drop_rate,
171
- act_type=act_type,
172
- qk_norm=qk_norm,
173
- qk_norm_type=qk_norm_type,
174
- qkv_bias=qkv_bias,
175
- **factory_kwargs,)
176
- # Zero-initialize the modulation
177
- nn.init.zeros_(self.adaLN_modulation[1].weight)
178
- nn.init.zeros_(self.adaLN_modulation[1].bias)
179
-
180
- def forward(
181
- self,
182
- x: torch.Tensor,
183
- c: torch.Tensor, # timestep_aware_representations + context_aware_representations
184
- attn_mask: torch.Tensor = None,
185
- y: torch.Tensor = None,
186
- ):
187
- gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
188
-
189
- norm_x = self.norm1(x)
190
- qkv = self.self_attn_qkv(norm_x)
191
- q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
192
- # Apply QK-Norm if needed
193
- q = self.self_attn_q_norm(q).to(v)
194
- k = self.self_attn_k_norm(k).to(v)
195
-
196
- # Self-Attention
197
- attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
198
-
199
- x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
200
-
201
- if self.need_CA:
202
- x = self.cross_attnblock(x, c, attn_mask, y)
203
-
204
- # FFN Layer
205
- x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
206
-
207
- return x
208
-
209
-
210
-
211
-
212
- class CrossAttnBlock(torch.nn.Module):
213
- def __init__(
214
- self,
215
- hidden_size,
216
- heads_num,
217
- mlp_width_ratio: float = 4.0,
218
- mlp_drop_rate: float = 0.0,
219
- act_type: str = "silu",
220
- qk_norm: bool = False,
221
- qk_norm_type: str = "layer",
222
- qkv_bias: bool = True,
223
- dtype: Optional[torch.dtype] = None,
224
- device: Optional[torch.device] = None,
225
- ):
226
- factory_kwargs = {"device": device, "dtype": dtype}
227
- super().__init__()
228
- self.heads_num = heads_num
229
- head_dim = hidden_size // heads_num
230
-
231
- self.norm1 = nn.LayerNorm(
232
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
233
- )
234
- self.norm1_2 = nn.LayerNorm(
235
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
236
- )
237
- self.self_attn_q = nn.Linear(
238
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
239
- )
240
- self.self_attn_kv = nn.Linear(
241
- hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
242
- )
243
- qk_norm_layer = get_norm_layer(qk_norm_type)
244
- self.self_attn_q_norm = (
245
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
246
- if qk_norm
247
- else nn.Identity()
248
- )
249
- self.self_attn_k_norm = (
250
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
251
- if qk_norm
252
- else nn.Identity()
253
- )
254
- self.self_attn_proj = nn.Linear(
255
- hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
256
- )
257
-
258
- self.norm2 = nn.LayerNorm(
259
- hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
260
- )
261
- act_layer = get_activation_layer(act_type)
262
-
263
- self.adaLN_modulation = nn.Sequential(
264
- act_layer(),
265
- nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
266
- )
267
- # Zero-initialize the modulation
268
- nn.init.zeros_(self.adaLN_modulation[1].weight)
269
- nn.init.zeros_(self.adaLN_modulation[1].bias)
270
-
271
- def forward(
272
- self,
273
- x: torch.Tensor,
274
- c: torch.Tensor, # timestep_aware_representations + context_aware_representations
275
- attn_mask: torch.Tensor = None,
276
- y: torch.Tensor=None,
277
-
278
- ):
279
- gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
280
-
281
- norm_x = self.norm1(x)
282
- norm_y = self.norm1_2(y)
283
- q = self.self_attn_q(norm_x)
284
- q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
285
- kv = self.self_attn_kv(norm_y)
286
- k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
287
- # Apply QK-Norm if needed
288
- q = self.self_attn_q_norm(q).to(v)
289
- k = self.self_attn_k_norm(k).to(v)
290
-
291
- # Self-Attention
292
- attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
293
-
294
- x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
295
-
296
- return x
297
-
298
-
299
-
300
- class IndividualTokenRefiner(torch.nn.Module):
301
- def __init__(
302
- self,
303
- hidden_size,
304
- heads_num,
305
- depth,
306
- mlp_width_ratio: float = 4.0,
307
- mlp_drop_rate: float = 0.0,
308
- act_type: str = "silu",
309
- qk_norm: bool = False,
310
- qk_norm_type: str = "layer",
311
- qkv_bias: bool = True,
312
- need_CA:bool=False,
313
- dtype: Optional[torch.dtype] = None,
314
- device: Optional[torch.device] = None,
315
- ):
316
-
317
- factory_kwargs = {"device": device, "dtype": dtype}
318
- super().__init__()
319
- self.need_CA = need_CA
320
- self.blocks = nn.ModuleList(
321
- [
322
- IndividualTokenRefinerBlock(
323
- hidden_size=hidden_size,
324
- heads_num=heads_num,
325
- mlp_width_ratio=mlp_width_ratio,
326
- mlp_drop_rate=mlp_drop_rate,
327
- act_type=act_type,
328
- qk_norm=qk_norm,
329
- qk_norm_type=qk_norm_type,
330
- qkv_bias=qkv_bias,
331
- need_CA=self.need_CA,
332
- **factory_kwargs,
333
- )
334
- for _ in range(depth)
335
- ]
336
- )
337
-
338
-
339
- def forward(
340
- self,
341
- x: torch.Tensor,
342
- c: torch.LongTensor,
343
- mask: Optional[torch.Tensor] = None,
344
- y:torch.Tensor=None,
345
- ):
346
- self_attn_mask = None
347
- if mask is not None:
348
- batch_size = mask.shape[0]
349
- seq_len = mask.shape[1]
350
- mask = mask.to(x.device)
351
- # batch_size x 1 x seq_len x seq_len
352
- self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
353
- 1, 1, seq_len, 1
354
- )
355
- # batch_size x 1 x seq_len x seq_len
356
- self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
357
- # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
358
- self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
359
- # avoids self-attention weight being NaN for padding tokens
360
- self_attn_mask[:, :, :, 0] = True
361
-
362
-
363
- for block in self.blocks:
364
- x = block(x, c, self_attn_mask,y)
365
-
366
- return x
367
-
368
-
369
- class SingleTokenRefiner(torch.nn.Module):
370
- """
371
- A single token refiner block for llm text embedding refine.
372
- """
373
- def __init__(
374
- self,
375
- in_channels,
376
- hidden_size,
377
- heads_num,
378
- depth,
379
- mlp_width_ratio: float = 4.0,
380
- mlp_drop_rate: float = 0.0,
381
- act_type: str = "silu",
382
- qk_norm: bool = False,
383
- qk_norm_type: str = "layer",
384
- qkv_bias: bool = True,
385
- need_CA:bool=False,
386
- attn_mode: str = "torch",
387
- dtype: Optional[torch.dtype] = None,
388
- device: Optional[torch.device] = None,
389
- ):
390
- factory_kwargs = {"device": device, "dtype": dtype}
391
- super().__init__()
392
- self.attn_mode = attn_mode
393
- self.need_CA = need_CA
394
- assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
395
-
396
- self.input_embedder = nn.Linear(
397
- in_channels, hidden_size, bias=True, **factory_kwargs
398
- )
399
- if self.need_CA:
400
- self.input_embedder_CA = nn.Linear(
401
- in_channels, hidden_size, bias=True, **factory_kwargs
402
- )
403
-
404
- act_layer = get_activation_layer(act_type)
405
- # Build timestep embedding layer
406
- self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
407
- # Build context embedding layer
408
- self.c_embedder = TextProjection(
409
- in_channels, hidden_size, act_layer, **factory_kwargs
410
- )
411
-
412
- self.individual_token_refiner = IndividualTokenRefiner(
413
- hidden_size=hidden_size,
414
- heads_num=heads_num,
415
- depth=depth,
416
- mlp_width_ratio=mlp_width_ratio,
417
- mlp_drop_rate=mlp_drop_rate,
418
- act_type=act_type,
419
- qk_norm=qk_norm,
420
- qk_norm_type=qk_norm_type,
421
- qkv_bias=qkv_bias,
422
- need_CA=need_CA,
423
- **factory_kwargs,
424
- )
425
-
426
- def forward(
427
- self,
428
- x: torch.Tensor,
429
- t: torch.LongTensor,
430
- mask: Optional[torch.LongTensor] = None,
431
- y: torch.LongTensor=None,
432
- ):
433
- timestep_aware_representations = self.t_embedder(t)
434
-
435
- if mask is None:
436
- context_aware_representations = x.mean(dim=1)
437
- else:
438
- mask_float = mask.unsqueeze(-1) # [b, s1, 1]
439
- context_aware_representations = (x * mask_float).sum(
440
- dim=1
441
- ) / mask_float.sum(dim=1)
442
- context_aware_representations = self.c_embedder(context_aware_representations)
443
- c = timestep_aware_representations + context_aware_representations
444
-
445
- x = self.input_embedder(x)
446
- if self.need_CA:
447
- y = self.input_embedder_CA(y)
448
- x = self.individual_token_refiner(x, c, mask, y)
449
- else:
450
- x = self.individual_token_refiner(x, c, mask)
451
-
452
- return x
453
-
454
-
455
-
456
- class Qwen2Connector(torch.nn.Module):
457
- def __init__(
458
- self,
459
- # biclip_dim=1024,
460
- in_channels=3584,
461
- hidden_size=4096,
462
- heads_num=32,
463
- depth=2,
464
- need_CA=False,
465
- device=None,
466
- dtype=torch.bfloat16,
467
- ):
468
- super().__init__()
469
- factory_kwargs = {"device": device, "dtype":dtype}
470
-
471
- self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
472
- self.global_proj_out=nn.Linear(in_channels,768)
473
-
474
- self.scale_factor = nn.Parameter(torch.zeros(1))
475
- with torch.no_grad():
476
- self.scale_factor.data += -(1 - 0.09)
477
-
478
- def forward(self, x,t,mask):
479
- mask_float = mask.unsqueeze(-1) # [b, s1, 1]
480
- x_mean = (x * mask_float).sum(
481
- dim=1
482
- ) / mask_float.sum(dim=1) * (1 + self.scale_factor)
483
-
484
- global_out=self.global_proj_out(x_mean)
485
- encoder_hidden_states = self.S(x,t,mask)
486
- return encoder_hidden_states,global_out
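
A shape-level sketch of the Qwen2Connector interface above, using its default construction arguments; the batch and sequence sizes are assumptions.

    # connector = Qwen2Connector(in_channels=3584, hidden_size=4096, heads_num=32, depth=2)
    # x:    (B, L, 3584)  last hidden states from the Qwen2.5-VL encoder
    # t:    (B,)          diffusion timestep per sample
    # mask: (B, L)        1 for valid tokens, 0 for padding
    # tokens, global_vec = connector(x, t, mask)   # tokens: (B, L, 4096), global_vec: (B, 768)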
 
modules/layers.py DELETED
@@ -1,640 +0,0 @@
1
- # Modified from Flux
2
- #
3
- # Copyright 2024 Black Forest Labs
4
-
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
-
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
-
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #
17
- # This source code is licensed under the license found in the
18
- # LICENSE file in the root directory of this source tree.
19
-
20
- import math # noqa: I001
21
- from dataclasses import dataclass
22
- from functools import partial
23
-
24
- import torch
25
- import torch.nn.functional as F
26
- from einops import rearrange
27
- # from liger_kernel.ops.rms_norm import LigerRMSNormFunction
28
- from torch import Tensor, nn
29
-
30
-
31
- try:
32
- import flash_attn
33
- from flash_attn.flash_attn_interface import (
34
- _flash_attn_forward,
35
- flash_attn_varlen_func,
36
- )
37
- except ImportError:
38
- flash_attn = None
39
- flash_attn_varlen_func = None
40
- _flash_attn_forward = None
41
-
42
-
43
- MEMORY_LAYOUT = {
44
- "flash": (
45
- lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
46
- lambda x: x,
47
- ),
48
- "torch": (
49
- lambda x: x.transpose(1, 2),
50
- lambda x: x.transpose(1, 2),
51
- ),
52
- "vanilla": (
53
- lambda x: x.transpose(1, 2),
54
- lambda x: x.transpose(1, 2),
55
- ),
56
- }
57
-
58
-
59
- def attention(
60
- q,
61
- k,
62
- v,
63
- mode="torch",
64
- drop_rate=0,
65
- attn_mask=None,
66
- causal=False,
67
- cu_seqlens_q=None,
68
- cu_seqlens_kv=None,
69
- max_seqlen_q=None,
70
- max_seqlen_kv=None,
71
- batch_size=1,
72
- ):
73
- """
74
- Perform QKV self attention.
75
-
76
- Args:
77
- q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
78
- k (torch.Tensor): Key tensor with shape [b, s1, a, d]
79
- v (torch.Tensor): Value tensor with shape [b, s1, a, d]
80
- mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
81
- drop_rate (float): Dropout rate in attention map. (default: 0)
82
- attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
83
- (default: None)
84
- causal (bool): Whether to use causal attention. (default: False)
85
- cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
86
- used to index into q.
87
- cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
88
- used to index into kv.
89
- max_seqlen_q (int): The maximum sequence length in the batch of q.
90
- max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
91
-
92
- Returns:
93
- torch.Tensor: Output tensor after self attention with shape [b, s, ad]
94
- """
95
- pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
96
- q = pre_attn_layout(q)
97
- k = pre_attn_layout(k)
98
- v = pre_attn_layout(v)
99
-
100
- if mode == "torch":
101
- if attn_mask is not None and attn_mask.dtype != torch.bool:
102
- attn_mask = attn_mask.to(q.dtype)
103
- x = F.scaled_dot_product_attention(
104
- q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
105
- )
106
- elif mode == "flash":
107
- assert flash_attn_varlen_func is not None
108
- x: torch.Tensor = flash_attn_varlen_func(
109
- q,
110
- k,
111
- v,
112
- cu_seqlens_q,
113
- cu_seqlens_kv,
114
- max_seqlen_q,
115
- max_seqlen_kv,
116
- ) # type: ignore
117
- # x with shape [(bxs), a, d]
118
- x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # type: ignore # reshape x to [b, s, a, d]
119
- elif mode == "vanilla":
120
- scale_factor = 1 / math.sqrt(q.size(-1))
121
-
122
- b, a, s, _ = q.shape
123
- s1 = k.size(2)
124
- attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
125
- if causal:
126
- # Only applied to self attention
127
- assert attn_mask is None, (
128
- "Causal mask and attn_mask cannot be used together"
129
- )
130
- temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
131
- diagonal=0
132
- )
133
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
134
- attn_bias = attn_bias.to(q.dtype)
135
-
136
- if attn_mask is not None:
137
- if attn_mask.dtype == torch.bool:
138
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
139
- else:
140
- attn_bias += attn_mask
141
-
142
- # TODO: Maybe force q and k to be float32 to avoid numerical overflow
143
- attn = (q @ k.transpose(-2, -1)) * scale_factor
144
- attn += attn_bias
145
- attn = attn.softmax(dim=-1)
146
- attn = torch.dropout(attn, p=drop_rate, train=True)
147
- x = attn @ v
148
- else:
149
- raise NotImplementedError(f"Unsupported attention mode: {mode}")
150
-
151
- x = post_attn_layout(x)
152
- b, s, a, d = x.shape
153
- out = x.reshape(b, s, -1)
154
- return out
155
-
156
-
157
- def apply_gate(x, gate=None, tanh=False):
158
- """Apply an optional gating tensor to the input.
159
-
160
- Args:
161
- x (torch.Tensor): input tensor.
162
- gate (torch.Tensor, optional): gate tensor. Defaults to None.
163
- tanh (bool, optional): whether to use tanh function. Defaults to False.
164
-
165
- Returns:
166
- torch.Tensor: the output tensor after applying the gate.
167
- """
168
- if gate is None:
169
- return x
170
- if tanh:
171
- return x * gate.unsqueeze(1).tanh()
172
- else:
173
- return x * gate.unsqueeze(1)
174
-
175
-
176
- class MLP(nn.Module):
177
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
178
-
179
- def __init__(
180
- self,
181
- in_channels,
182
- hidden_channels=None,
183
- out_features=None,
184
- act_layer=nn.GELU,
185
- norm_layer=None,
186
- bias=True,
187
- drop=0.0,
188
- use_conv=False,
189
- device=None,
190
- dtype=None,
191
- ):
192
- super().__init__()
193
- out_features = out_features or in_channels
194
- hidden_channels = hidden_channels or in_channels
195
- bias = (bias, bias)
196
- drop_probs = (drop, drop)
197
- linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
198
-
199
- self.fc1 = linear_layer(
200
- in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
201
- )
202
- self.act = act_layer()
203
- self.drop1 = nn.Dropout(drop_probs[0])
204
- self.norm = (
205
- norm_layer(hidden_channels, device=device, dtype=dtype)
206
- if norm_layer is not None
207
- else nn.Identity()
208
- )
209
- self.fc2 = linear_layer(
210
- hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
211
- )
212
- self.drop2 = nn.Dropout(drop_probs[1])
213
-
214
- def forward(self, x):
215
- x = self.fc1(x)
216
- x = self.act(x)
217
- x = self.drop1(x)
218
- x = self.norm(x)
219
- x = self.fc2(x)
220
- x = self.drop2(x)
221
- return x
222
-
223
-
224
- class TextProjection(nn.Module):
225
- """
226
- Projects text embeddings. Also handles dropout for classifier-free guidance.
227
-
228
- Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
229
- """
230
-
231
- def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
232
- factory_kwargs = {"dtype": dtype, "device": device}
233
- super().__init__()
234
- self.linear_1 = nn.Linear(
235
- in_features=in_channels,
236
- out_features=hidden_size,
237
- bias=True,
238
- **factory_kwargs,
239
- )
240
- self.act_1 = act_layer()
241
- self.linear_2 = nn.Linear(
242
- in_features=hidden_size,
243
- out_features=hidden_size,
244
- bias=True,
245
- **factory_kwargs,
246
- )
247
-
248
- def forward(self, caption):
249
- hidden_states = self.linear_1(caption)
250
- hidden_states = self.act_1(hidden_states)
251
- hidden_states = self.linear_2(hidden_states)
252
- return hidden_states
253
-
254
-
255
- class TimestepEmbedder(nn.Module):
256
- """
257
- Embeds scalar timesteps into vector representations.
258
- """
259
-
260
- def __init__(
261
- self,
262
- hidden_size,
263
- act_layer,
264
- frequency_embedding_size=256,
265
- max_period=10000,
266
- out_size=None,
267
- dtype=None,
268
- device=None,
269
- ):
270
- factory_kwargs = {"dtype": dtype, "device": device}
271
- super().__init__()
272
- self.frequency_embedding_size = frequency_embedding_size
273
- self.max_period = max_period
274
- if out_size is None:
275
- out_size = hidden_size
276
-
277
- self.mlp = nn.Sequential(
278
- nn.Linear(
279
- frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
280
- ),
281
- act_layer(),
282
- nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
283
- )
284
- nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
285
- nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
286
-
287
- @staticmethod
288
- def timestep_embedding(t, dim, max_period=10000):
289
- """
290
- Create sinusoidal timestep embeddings.
291
-
292
- Args:
293
- t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
294
- dim (int): the dimension of the output.
295
- max_period (int): controls the minimum frequency of the embeddings.
296
-
297
- Returns:
298
- embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
299
-
300
- .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
301
- """
302
- half = dim // 2
303
- freqs = torch.exp(
304
- -math.log(max_period)
305
- * torch.arange(start=0, end=half, dtype=torch.float32)
306
- / half
307
- ).to(device=t.device)
308
- args = t[:, None].float() * freqs[None]
309
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
310
- if dim % 2:
311
- embedding = torch.cat(
312
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
313
- )
314
- return embedding
315
-
316
- def forward(self, t):
317
- t_freq = self.timestep_embedding(
318
- t, self.frequency_embedding_size, self.max_period
319
- ).type(self.mlp[0].weight.dtype) # type: ignore
320
- t_emb = self.mlp(t_freq)
321
- return t_emb
322
-
323
-
324
- class EmbedND(nn.Module):
325
- def __init__(self, dim: int, theta: int, axes_dim: list[int]):
326
- super().__init__()
327
- self.dim = dim
328
- self.theta = theta
329
- self.axes_dim = axes_dim
330
-
331
- def forward(self, ids: Tensor) -> Tensor:
332
- n_axes = ids.shape[-1]
333
- emb = torch.cat(
334
- [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
335
- dim=-3,
336
- )
337
-
338
- return emb.unsqueeze(1)
339
-
340
-
341
- class MLPEmbedder(nn.Module):
342
- def __init__(self, in_dim: int, hidden_dim: int):
343
- super().__init__()
344
- self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
345
- self.silu = nn.SiLU()
346
- self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
347
-
348
- def forward(self, x: Tensor) -> Tensor:
349
- return self.out_layer(self.silu(self.in_layer(x)))
350
-
351
-
352
- def rope(pos, dim: int, theta: int):
353
- assert dim % 2 == 0
354
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
355
- omega = 1.0 / (theta**scale)
356
- out = torch.einsum("...n,d->...nd", pos, omega)
357
- out = torch.stack(
358
- [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
359
- )
360
- out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
361
- return out.float()
362
-
363
-
364
- def attention_after_rope(q, k, v, pe):
365
- q, k = apply_rope(q, k, pe)
366
-
367
- from .attention import attention
368
-
369
- x = attention(q, k, v, mode="torch")
370
- return x
371
-
372
-
373
- @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
374
- def apply_rope(xq, xk, freqs_cis):
375
- # swap the num_heads and seq_len dimensions back to the order the original function expects
376
- xq = xq.transpose(1, 2) # [batch, num_heads, seq_len, head_dim]
377
- xk = xk.transpose(1, 2)
378
-
379
- # split head_dim into complex components (real and imaginary parts)
380
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
381
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
382
-
383
- # apply the rotary position embedding (complex multiplication)
384
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
385
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
386
-
387
- # restore the tensor shape and transpose back to the target dimension order
388
- xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
389
- xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)
390
-
391
- return xq_out, xk_out
392
-
393
-
394
- @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
395
- def scale_add_residual(
396
- x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
397
- ) -> torch.Tensor:
398
- return x * scale + residual
399
-
400
-
401
- @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
402
- def layernorm_and_scale_shift(
403
- x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
404
- ) -> torch.Tensor:
405
- return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift
406
-
407
-
408
- class SelfAttention(nn.Module):
409
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
410
- super().__init__()
411
- self.num_heads = num_heads
412
- head_dim = dim // num_heads
413
-
414
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
415
- self.norm = QKNorm(head_dim)
416
- self.proj = nn.Linear(dim, dim)
417
-
418
- def forward(self, x: Tensor, pe: Tensor) -> Tensor:
419
- qkv = self.qkv(x)
420
- q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
421
- q, k = self.norm(q, k, v)
422
- x = attention_after_rope(q, k, v, pe=pe)
423
- x = self.proj(x)
424
- return x
425
-
426
-
427
- @dataclass
428
- class ModulationOut:
429
- shift: Tensor
430
- scale: Tensor
431
- gate: Tensor
432
-
433
-
434
- class RMSNorm(torch.nn.Module):
435
- def __init__(self, dim: int):
436
- super().__init__()
437
- self.scale = nn.Parameter(torch.ones(dim))
438
-
439
- # @staticmethod
440
- # def rms_norm_fast(x, weight, eps):
441
- # return LigerRMSNormFunction.apply(
442
- # x,
443
- # weight,
444
- # eps,
445
- # 0.0,
446
- # "gemma",
447
- # True,
448
- # )
449
-
450
- @staticmethod
451
- def rms_norm(x, weight, eps):
452
- x_dtype = x.dtype
453
- x = x.float()
454
- rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
455
- return (x * rrms).to(dtype=x_dtype) * weight
456
-
457
- def forward(self, x: Tensor):
458
- # return self.rms_norm_fast(x, self.scale, 1e-6)
459
- return self.rms_norm(x, self.scale, 1e-6)
460
-
461
-
462
- class QKNorm(torch.nn.Module):
463
- def __init__(self, dim: int):
464
- super().__init__()
465
- self.query_norm = RMSNorm(dim)
466
- self.key_norm = RMSNorm(dim)
467
-
468
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
469
- q = self.query_norm(q)
470
- k = self.key_norm(k)
471
- return q.to(v), k.to(v)
472
-
473
-
474
- class Modulation(nn.Module):
475
- def __init__(self, dim: int, double: bool):
476
- super().__init__()
477
- self.is_double = double
478
- self.multiplier = 6 if double else 3
479
- self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
480
-
481
- def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
482
- out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
483
- self.multiplier, dim=-1
484
- )
485
-
486
- return (
487
- ModulationOut(*out[:3]),
488
- ModulationOut(*out[3:]) if self.is_double else None,
489
- )
490
-
491
-
492
- class DoubleStreamBlock(nn.Module):
493
- def __init__(
494
- self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
495
- ):
496
- super().__init__()
497
-
498
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
499
- self.num_heads = num_heads
500
- self.hidden_size = hidden_size
501
- self.img_mod = Modulation(hidden_size, double=True)
502
- self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
503
- self.img_attn = SelfAttention(
504
- dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
505
- )
506
-
507
- self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
508
- self.img_mlp = nn.Sequential(
509
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
510
- nn.GELU(approximate="tanh"),
511
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
512
- )
513
-
514
- self.txt_mod = Modulation(hidden_size, double=True)
515
- self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
516
- self.txt_attn = SelfAttention(
517
- dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
518
- )
519
-
520
- self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
521
- self.txt_mlp = nn.Sequential(
522
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
523
- nn.GELU(approximate="tanh"),
524
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
525
- )
526
-
527
- def forward(
528
- self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
529
- ) -> tuple[Tensor, Tensor]:
530
- img_mod1, img_mod2 = self.img_mod(vec)
531
- txt_mod1, txt_mod2 = self.txt_mod(vec)
532
-
533
- # prepare image for attention
534
- img_modulated = self.img_norm1(img)
535
- img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
536
- img_qkv = self.img_attn.qkv(img_modulated)
537
- img_q, img_k, img_v = rearrange(
538
- img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
539
- )
540
- img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
541
-
542
- # prepare txt for attention
543
- txt_modulated = self.txt_norm1(txt)
544
- txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
545
- txt_qkv = self.txt_attn.qkv(txt_modulated)
546
- txt_q, txt_k, txt_v = rearrange(
547
- txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
548
- )
549
- txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
550
-
551
- # run actual attention
552
- q = torch.cat((txt_q, img_q), dim=1)
553
- k = torch.cat((txt_k, img_k), dim=1)
554
- v = torch.cat((txt_v, img_v), dim=1)
555
-
556
- attn = attention_after_rope(q, k, v, pe=pe)
557
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
558
-
559
- # calculate the img blocks
560
- img = img + img_mod1.gate * self.img_attn.proj(img_attn)
561
- img_mlp = self.img_mlp(
562
- (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
563
- )
564
- img = scale_add_residual(img_mlp, img_mod2.gate, img)
565
-
566
- # calculate the txt blocks
567
- txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
568
- txt_mlp = self.txt_mlp(
569
- (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
570
- )
571
- txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
572
- return img, txt
573
-
574
-
575
- class SingleStreamBlock(nn.Module):
576
- """
577
- A DiT block with parallel linear layers as described in
578
- https://arxiv.org/abs/2302.05442, with an adapted modulation interface.
579
- """
580
-
581
- def __init__(
582
- self,
583
- hidden_size: int,
584
- num_heads: int,
585
- mlp_ratio: float = 4.0,
586
- qk_scale: float | None = None,
587
- ):
588
- super().__init__()
589
- self.hidden_dim = hidden_size
590
- self.num_heads = num_heads
591
- head_dim = hidden_size // num_heads
592
- self.scale = qk_scale or head_dim**-0.5
593
-
594
- self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
595
- # qkv and mlp_in
596
- self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
597
- # proj and mlp_out
598
- self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
599
-
600
- self.norm = QKNorm(head_dim)
601
-
602
- self.hidden_size = hidden_size
603
- self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
604
-
605
- self.mlp_act = nn.GELU(approximate="tanh")
606
- self.modulation = Modulation(hidden_size, double=False)
607
-
608
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
609
- mod, _ = self.modulation(vec)
610
- x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
611
- qkv, mlp = torch.split(
612
- self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
613
- )
614
-
615
- q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
616
- q, k = self.norm(q, k, v)
617
-
618
- # compute attention
619
- attn = attention_after_rope(q, k, v, pe=pe)
620
- # compute activation in mlp stream, cat again and run second linear layer
621
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
622
- return scale_add_residual(output, mod.gate, x)
623
-
624
-
625
- class LastLayer(nn.Module):
626
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
627
- super().__init__()
628
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
629
- self.linear = nn.Linear(
630
- hidden_size, patch_size * patch_size * out_channels, bias=True
631
- )
632
- self.adaLN_modulation = nn.Sequential(
633
- nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
634
- )
635
-
636
- def forward(self, x: Tensor, vec: Tensor) -> Tensor:
637
- shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
638
- x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
639
- x = self.linear(x)
640
- return x
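
Note on the deleted helpers above: the snippet below is not part of the repository. It is a minimal, self-contained sketch (toy tensor sizes, a single positional axis, and the "torch" attention path are assumptions chosen for illustration) of how rope(), apply_rope() and the memory-layout handling in this file fit together.

import torch
import torch.nn.functional as F
from einops import rearrange

def rope(pos, dim, theta):
    # same construction as the deleted rope(): per-position 2x2 rotation matrices
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    return rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2).float()

b, seq, heads, head_dim = 1, 16, 4, 32                      # toy sizes (assumption)
q, k, v = (torch.randn(b, seq, heads, head_dim) for _ in range(3))

pos = torch.arange(seq, dtype=torch.float64)[None]          # [b, n] positions, single axis
pe = rope(pos, head_dim, theta=10_000).unsqueeze(1)         # [b, 1, n, head_dim/2, 2, 2], as EmbedND would emit for one axis

# apply_rope(): rotate each (even, odd) channel pair of q and k as a complex number
xq = q.transpose(1, 2).float().reshape(b, heads, seq, -1, 1, 2)
xk = k.transpose(1, 2).float().reshape(b, heads, seq, -1, 1, 2)
q_rot = (pe[..., 0] * xq[..., 0] + pe[..., 1] * xq[..., 1]).reshape(b, heads, seq, head_dim).transpose(1, 2)
k_rot = (pe[..., 0] * xk[..., 0] + pe[..., 1] * xk[..., 1]).reshape(b, heads, seq, head_dim).transpose(1, 2)

# "torch" mode from MEMORY_LAYOUT: [b, s, a, d] -> [b, a, s, d] for SDPA, then back to [b, s, a*d]
out = F.scaled_dot_product_attention(q_rot.transpose(1, 2), k_rot.transpose(1, 2), v.transpose(1, 2))
out = out.transpose(1, 2).reshape(b, seq, heads * head_dim)
print(out.shape)                                            # torch.Size([1, 16, 128])
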
 
modules/model_edit.py DELETED
@@ -1,143 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
-
4
- import numpy as np
5
- import torch
6
- from torch import Tensor, nn
7
-
8
- from .connector_edit import Qwen2Connector
9
- from .layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock
10
-
11
-
12
- @dataclass
13
- class Step1XParams:
14
- in_channels: int
15
- out_channels: int
16
- vec_in_dim: int
17
- context_in_dim: int
18
- hidden_size: int
19
- mlp_ratio: float
20
- num_heads: int
21
- depth: int
22
- depth_single_blocks: int
23
- axes_dim: list[int]
24
- theta: int
25
- qkv_bias: bool
26
-
27
-
28
- class Step1XEdit(nn.Module):
29
- """
30
- Transformer model for flow matching on sequences.
31
- """
32
-
33
- def __init__(self, params: Step1XParams):
34
- super().__init__()
35
-
36
- self.params = params
37
- self.in_channels = params.in_channels
38
- self.out_channels = params.out_channels
39
- if params.hidden_size % params.num_heads != 0:
40
- raise ValueError(
41
- f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
42
- )
43
- pe_dim = params.hidden_size // params.num_heads
44
- if sum(params.axes_dim) != pe_dim:
45
- raise ValueError(
46
- f"Got {params.axes_dim} but expected positional dim {pe_dim}"
47
- )
48
- self.hidden_size = params.hidden_size
49
- self.num_heads = params.num_heads
50
- self.pe_embedder = EmbedND(
51
- dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
52
- )
53
- self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
54
- self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
55
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
56
- self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
57
-
58
- self.double_blocks = nn.ModuleList(
59
- [
60
- DoubleStreamBlock(
61
- self.hidden_size,
62
- self.num_heads,
63
- mlp_ratio=params.mlp_ratio,
64
- qkv_bias=params.qkv_bias,
65
- )
66
- for _ in range(params.depth)
67
- ]
68
- )
69
-
70
- self.single_blocks = nn.ModuleList(
71
- [
72
- SingleStreamBlock(
73
- self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
74
- )
75
- for _ in range(params.depth_single_blocks)
76
- ]
77
- )
78
-
79
- self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
80
-
81
- self.connector = Qwen2Connector()
82
-
83
- @staticmethod
84
- def timestep_embedding(
85
- t: Tensor, dim, max_period=10000, time_factor: float = 1000.0
86
- ):
87
- """
88
- Create sinusoidal timestep embeddings.
89
- :param t: a 1-D Tensor of N indices, one per batch element.
90
- These may be fractional.
91
- :param dim: the dimension of the output.
92
- :param max_period: controls the minimum frequency of the embeddings.
93
- :return: an (N, D) Tensor of positional embeddings.
94
- """
95
- t = time_factor * t
96
- half = dim // 2
97
- freqs = torch.exp(
98
- -math.log(max_period)
99
- * torch.arange(start=0, end=half, dtype=torch.float32)
100
- / half
101
- ).to(t.device)
102
-
103
- args = t[:, None].float() * freqs[None]
104
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
105
- if dim % 2:
106
- embedding = torch.cat(
107
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
108
- )
109
- if torch.is_floating_point(t):
110
- embedding = embedding.to(t)
111
- return embedding
112
-
113
- def forward(
114
- self,
115
- img: Tensor,
116
- img_ids: Tensor,
117
- txt: Tensor,
118
- txt_ids: Tensor,
119
- timesteps: Tensor,
120
- y: Tensor,
121
- ) -> Tensor:
122
- if img.ndim != 3 or txt.ndim != 3:
123
- raise ValueError("Input img and txt tensors must have 3 dimensions.")
124
-
125
- img = self.img_in(img)
126
- vec = self.time_in(self.timestep_embedding(timesteps, 256))
127
-
128
- vec = vec + self.vector_in(y)
129
- txt = self.txt_in(txt)
130
-
131
- ids = torch.cat((txt_ids, img_ids), dim=1)
132
- pe = self.pe_embedder(ids)
133
-
134
- for block in self.double_blocks:
135
- img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
136
-
137
- img = torch.cat((txt, img), 1)
138
- for block in self.single_blocks:
139
- img = block(img, vec=vec, pe=pe)
140
- img = img[:, txt.shape[1] :, ...]
141
-
142
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
143
- return img
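
A side note on the removed forward pass: timesteps enter the model through a 256-dimensional sinusoidal embedding before time_in. The standalone sketch below is not taken from the repository; it reproduces that embedding with illustrative timestep values (the [0, 1] flow-matching range is an assumption, time_factor=1000.0 is the default shown above) so the removed conditioning path is easy to inspect.

import math
import torch

def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
    # mirrors Step1XEdit.timestep_embedding above: cos/sin sinusoids over half the channels each
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(t.device)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

timesteps = torch.tensor([0.0, 0.25, 1.0])       # illustrative flow-matching timesteps
emb = timestep_embedding(timesteps, dim=256)     # matches the in_dim=256 of time_in
print(emb.shape)                                 # torch.Size([3, 256])
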
 
no_cookie.png DELETED

Git LFS Details

  • SHA256: 4ee90a1e41774e2dae54ca436874341e750f2c7a6196b8360aee1952e98066f8
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
poster.jpg DELETED
Binary file (65.4 kB)
 
poster_orig.jpg DELETED

Git LFS Details

  • SHA256: 92a4178a56e7fefd7dfd418c675c1ab6b6b2e00e17b45a778a1100ab62f9bfba
  • Pointer size: 131 Bytes
  • Size of remote file: 458 kB