Shaoan committed on
Commit 642f8f3 · verified · 1 Parent(s): d2f46ae

Upload folder using huggingface_hub

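For reference, an upload like the one in this commit is typically issued with the `huggingface_hub` client; a minimal sketch (the repo id below is a placeholder, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=".",                  # local folder containing app.py, aligner.py, pipeline.py, ...
    repo_id="<user>/<space-name>",    # hypothetical target repo / Space
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)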
Files changed (7)
  1. aligner.py +861 -0
  2. app.py +210 -0
  3. empty_pooled_clip.pt +3 -0
  4. pipeline.py +641 -0
  5. requirements.txt +8 -0
  6. requirements.txt.py +8 -0
  7. text_encoder.py +1188 -0
aligner.py ADDED
@@ -0,0 +1,861 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from refiner import Qwen2Connector
+
+
18
+ class MultiHeadSelfAttention(nn.Module):
19
+ def __init__(self, embed_dim=2560, num_heads=20):
20
+ super().__init__()
21
+ assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
22
+
23
+ self.embed_dim = embed_dim
24
+ self.num_heads = num_heads
25
+ self.head_dim = embed_dim // num_heads
26
+
27
+ # Linear projections for Q, K, V
28
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
29
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
30
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
31
+
32
+ # Output projection
33
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
34
+
35
+ self.scale = self.head_dim ** -0.5
36
+
37
+ def forward(self, x, mask=None, return_attention=True):
38
+ """
39
+ Args:
40
+ x: Input tensor of shape [b, seq_len, embed_dim]
41
+ mask: Attention mask of shape [b, seq_len], where 1 means attend, 0 means ignore
42
+ return_attention: Whether to return attention weights
43
+
44
+ Returns:
45
+ output: [b, seq_len, embed_dim]
46
+ attn_weights: [b*num_heads, seq_len, seq_len] (if return_attention=True)
47
+ """
48
+ b, seq_len, embed_dim = x.shape
49
+
50
+ # Project to Q, K, V
51
+ Q = self.q_proj(x) # [b, seq_len, embed_dim]
52
+ K = self.k_proj(x) # [b, seq_len, embed_dim]
53
+ V = self.v_proj(x) # [b, seq_len, embed_dim]
54
+
55
+ # Reshape and transpose for multi-head attention
56
+ # [b, seq_len, embed_dim] -> [b, seq_len, num_heads, head_dim] -> [b, num_heads, seq_len, head_dim]
57
+ Q = Q.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
58
+ K = K.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
59
+ V = V.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
60
+
61
+ # Reshape for batch computation: [b, num_heads, seq_len, head_dim] -> [b*num_heads, seq_len, head_dim]
62
+ Q = Q.reshape(b * self.num_heads, seq_len, self.head_dim)
63
+ K = K.reshape(b * self.num_heads, seq_len, self.head_dim)
64
+ V = V.reshape(b * self.num_heads, seq_len, self.head_dim)
65
+
66
+ # Compute attention scores: Q @ K^T
67
+ attn_scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale # [b*num_heads, seq_len, seq_len]
68
+
69
+ # Apply mask if provided
70
+ if mask is not None:
71
+ # Key mask (column masking): which keys can be attended to
72
+ key_mask = mask.unsqueeze(1).unsqueeze(2) # [b, 1, 1, seq_len]
73
+
74
+ # Query mask (row masking): which queries are valid
75
+ query_mask = mask.unsqueeze(1).unsqueeze(3) # [b, 1, seq_len, 1]
76
+
77
+ # Combine both masks: a position can attend only if BOTH query and key are valid
78
+ # Shape: [b, 1, seq_len, seq_len]
79
+ final_mask = query_mask.bool() & key_mask.bool() # Broadcasting handles the dimensions
80
+
81
+ # Expand to all heads and reshape
82
+ final_mask = final_mask.expand(b, self.num_heads, seq_len, seq_len)
83
+ final_mask = final_mask.reshape(b * self.num_heads, seq_len, seq_len)
84
+
85
+ attn_scores = attn_scores.masked_fill(~final_mask, float('-inf'))
86
+
87
+ # Apply softmax
88
+ attn_weights = F.softmax(attn_scores, dim=-1) # [b*num_heads, seq_len, seq_len]
89
+
90
+ # Handle NaN from softmax (when entire row is -inf)
91
+ attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
92
+
93
+ # Apply attention to values
94
+ attn_output = torch.bmm(attn_weights, V) # [b*num_heads, seq_len, head_dim]
95
+
96
+ # Reshape back: [b*num_heads, seq_len, head_dim] -> [b, num_heads, seq_len, head_dim]
97
+ attn_output = attn_output.view(b, self.num_heads, seq_len, self.head_dim)
98
+
99
+ # Transpose and reshape: [b, num_heads, seq_len, head_dim] -> [b, seq_len, num_heads, head_dim] -> [b, seq_len, embed_dim]
100
+ attn_output = attn_output.transpose(1, 2).contiguous().view(b, seq_len, embed_dim)
101
+
102
+ # Final output projection
103
+ output = self.out_proj(attn_output) # [b, seq_len, embed_dim]
104
+
105
+ if return_attention:
106
+ return output, attn_weights # attn_weights is [b*num_heads, seq_len, seq_len]
107
+ else:
108
+ return output
109
+
110
+
111
+ class ConceptAligner222(nn.Module):
112
+ def __init__(self, custom_pool=1, input_dim=2560, hidden_size=2560):
113
+ super().__init__()
114
+ if input_dim == 2560:
115
+ hidden_size = 2560
116
+ self.num_heads = 20
117
+ self.model_class = 'gemma3'
118
+ depth = 2
119
+ identity_mapping = False
120
+
121
+ elif input_dim == 4096:
122
+ hidden_size = 3072
123
+ self.num_heads = 24
124
+ self.model_class = 't5'
125
+ depth = 1
126
+ identity_mapping = True
127
+
128
+ self.text_connector = Qwen2Connector(in_channels=input_dim, hidden_size=hidden_size, heads_num=self.num_heads,
129
+ depth=depth, identity_init=identity_mapping)
130
+ self.final_proj = nn.Sequential(nn.Linear(hidden_size, 4096), nn.SiLU(), nn.Linear(4096, 4096))
131
+ self.resampler = MultiHeadSelfAttention(embed_dim=hidden_size, num_heads=self.num_heads)
132
+ empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
133
+ self.register_buffer('empty_pooled_clip', empty_pooled_clip)
134
+ self.learnable_scale_norm = nn.Parameter(torch.ones([1, 1, 1]) * 0.01, requires_grad=True)
135
+ self.proj_norm = nn.LayerNorm(hidden_size)
136
+ self.custom_pool = custom_pool
137
+ if self.custom_pool:
138
+ self.clip_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size * 3), nn.SiLU(),
139
+ nn.Linear(hidden_size * 3, 768))
140
+ self.clip_norm = nn.LayerNorm(768)
141
+ print('Using custom pooling for CLIP features.')
142
+
143
+ @property
144
+ def dtype(self):
145
+ """Return the dtype of the model parameters."""
146
+ # return next(self.parameters()).dtype
147
+ return torch.bfloat16
148
+
149
+ @property
150
+ def device(self):
151
+ """Return the device of the model parameters."""
152
+ # return next(self.parameters()).device
153
+ return self.empty_pooled_clip.device
154
+
155
+ def forward(self, text_features, text_mask, is_training=False, img_seq_len=1024):
156
+ text_features = self.text_connector(text_features, mask=text_mask,
157
+ mean_start_id=2 if self.model_class == 'gemma' else 0)
158
+ text_features = self.proj_norm(text_features)
159
+ aligned_features, attn = self.resampler(text_features, mask=text_mask, return_attention=True)
160
+ if is_training:
161
+ learnable_scale = torch.clip(self.learnable_scale_norm, -1.0, 1.0)
162
+ visual_concepts = aligned_features + learnable_scale * torch.randn_like(aligned_features)
163
+ else:
164
+ visual_concepts = aligned_features
165
+ prompt_embeds = self.final_proj(visual_concepts)
166
+ # prompt_embeds = text_features
167
+ if self.custom_pool:
168
+ mean_features = (aligned_features * text_mask.unsqueeze(-1)).sum(dim=1) / (
169
+ text_mask.sum(dim=1, keepdim=True) + 1e-8)
170
+ pooled_prompt_embeds = self.clip_proj(mean_features)
171
+ pooled_prompt_embeds = self.clip_norm(pooled_prompt_embeds)
172
+ else:
173
+ pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)
174
+ dtype = prompt_embeds.dtype
175
+ device = prompt_embeds.device
176
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
177
+
178
+ total_seq_len = img_seq_len + prompt_embeds.shape[1]
179
+ text_seq_len = text_mask.shape[1]
180
+ attention_mask = torch.zeros(
181
+ len(text_features), 1, 1, total_seq_len,
182
+ device=text_mask.device,
183
+ dtype=text_mask.dtype
184
+ )
185
+ # Fill in text portion: where text_mask==0, set to -inf
186
+ attention_mask[:, :, :, :text_seq_len] = (1 - text_mask).unsqueeze(1).unsqueeze(2) * -10000.0
187
+
188
+ entropy = -(attn * torch.log(attn + 1e-8)).sum(dim=-1)
189
+ mask_expanded = text_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
190
+ mask_expanded = mask_expanded.reshape(len(text_features) * self.num_heads, text_seq_len)
191
+ valid_entropy = entropy[mask_expanded.bool()]
192
+
193
+ return prompt_embeds, attention_mask, pooled_prompt_embeds, text_ids, valid_entropy
194
+ # return prompt_embeds, pooled_prompt_embeds, text_ids, None
195
+
196
+
197
+ import torch
198
+ import torch.nn as nn
199
+
200
+
201
+ class RMSNorm(nn.Module):
202
+ def __init__(self, dim: int, eps: float = 1e-6):
203
+ super().__init__()
204
+ self.eps = eps
205
+ self.weight = nn.Parameter(torch.ones(dim))
206
+
207
+ def _norm(self, x):
208
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
209
+
210
+ def forward(self, x):
211
+ output = self._norm(x.float()).type_as(x)
212
+ return output * self.weight
213
+
214
+
215
+ class AdaLayerNorm(nn.Module):
216
+ def __init__(self, embedding_dim: int, time_embedding_dim=4096):
217
+ super().__init__()
218
+
219
+ if time_embedding_dim is None:
220
+ time_embedding_dim = embedding_dim
221
+
222
+ self.silu = nn.SiLU()
223
+ self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
224
+ nn.init.normal_(self.linear.weight, mean=0, std=0.02)
225
+ nn.init.zeros_(self.linear.bias)
226
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
227
+
228
+ def forward(
229
+ self, x: torch.Tensor, timestep_embedding: torch.Tensor
230
+ ) -> tuple[torch.Tensor, torch.Tensor]:
231
+ emb = self.linear(self.silu(timestep_embedding))
232
+ shift, scale = emb.unsqueeze(1).chunk(2, dim=-1)
233
+ x = self.norm(x) * (1 + scale) + shift
234
+ return x
235
+
236
+
237
+ class GateMLP(nn.Module):
238
+ def __init__(self, gate_mode='soft', input_dim=64, hidden_dim=1024):
239
+ super().__init__()
240
+ self.gate_mode = gate_mode
241
+ hidden_dim = max(input_dim, min(hidden_dim, 512))
242
+ hidden_dim = 512
243
+ self.input_norm = nn.LayerNorm(4096)
244
+
245
+ self.norm0 = nn.LayerNorm(input_dim)
246
+ self.linear1 = nn.Linear(input_dim, hidden_dim)
247
+ self.activation1 = nn.GELU()
248
+ self.linear2 = nn.Linear(hidden_dim+4096, hidden_dim)
249
+ self.activation2 = nn.GELU()
250
+ self.linear3 = nn.Linear(hidden_dim+4096, hidden_dim)
251
+ self.activation3 = nn.GELU()
252
+ self.final_linear = nn.Linear(hidden_dim, 1)
253
+
254
+ nn.init.xavier_uniform_(self.linear1.weight)
255
+ nn.init.zeros_(self.linear1.bias)
256
+
257
+ nn.init.xavier_uniform_(self.linear2.weight)
258
+ nn.init.zeros_(self.linear2.bias)
259
+
260
+ nn.init.xavier_uniform_(self.linear3.weight)
261
+ nn.init.zeros_(self.linear3.bias)
262
+
263
+ nn.init.zeros_(self.final_linear.weight)
264
+ bias_val = 0.0 if 'soft' in gate_mode else 1.0
265
+ nn.init.constant_(self.final_linear.bias, bias_val)
266
+
267
+ def forward(self, x):
268
+ y = x.transpose(1, 2).flatten(2)
269
+ y = self.input_norm(y.detach()).unsqueeze(1).repeat(1, x.shape[1],1,1)
270
+ x = self.linear1(self.norm0(x.detach()))
271
+ x = self.activation1(x)
272
+ x = self.linear2(torch.cat([x, y], dim=-1))
273
+ x = self.activation2(x)
274
+ x = self.linear3(torch.cat([x,y], dim=-1))
275
+ x = self.activation3(x)
276
+ x = self.final_linear(x)
277
+ return x
278
+
279
+
280
+ class CrossAttentionWithInfluence(nn.Module):
281
+ def __init__(self, d_model=4096, num_heads=32, gate_mode='hard'):
282
+ super().__init__()
283
+ self.d_model = d_model
284
+ self.num_heads = num_heads
285
+ self.head_dim = d_model // num_heads
286
+ self.gate_mode = gate_mode
287
+
288
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
289
+
290
+ # Linear projections for Q, K, V
291
+ # self.q_proj = nn.Linear(d_model, d_model)
292
+ # self.k_proj = nn.Linear(d_model, d_model)
293
+ self.v_proj = nn.Linear(d_model, d_model)
294
+ self.out_proj = nn.Linear(d_model, d_model)
295
+
296
+ # nn.init.normal_(self.q_proj.weight, mean=0, std=0.02)
297
+ # nn.init.normal_(self.k_proj.weight, mean=0, std=0.02)
298
+ # nn.init.zeros_(self.q_proj.bias)
299
+ # nn.init.zeros_(self.k_proj.bias)
300
+ nn.init.eye_(self.out_proj.weight)
301
+ nn.init.zeros_(self.out_proj.bias)
302
+ nn.init.eye_(self.v_proj.weight)
303
+ nn.init.zeros_(self.v_proj.bias)
304
+
305
+ self.mask_mlp = GateMLP(input_dim=d_model // num_heads, hidden_dim=1024, gate_mode=gate_mode)
306
+
307
+ self.scale = self.head_dim ** -0.5
308
+ # self.learnable_scale_norm = nn.Parameter(torch.ones([1, 1,1,1])*0.01, requires_grad=True)
309
+
310
+ self.rec_mlp = nn.Sequential(nn.Linear(4096, 4096), nn.SiLU(),
311
+ nn.Linear(4096, 4096), nn.SiLU(),
312
+ nn.Linear(4096, 4096)
313
+ )
314
+
315
+ def forward(self, x, y, y_mask, temperature=None, threshold=None, topk=None):
316
+ """
317
+ Args:
318
+ x: shared embedding [b, 300, 4096]
319
+ y: changing embedding [b, 300, 4096]
320
+
321
+ Returns:
322
+ output: [b, 300, 4096]
323
+ y_influence: [b, 32, 300, 300] - influence from y to x
324
+ """
325
+ b, seq_len_x, d_model = x.shape
326
+ b, seq_len_y, d_model_y = y.shape
327
+
328
+ """
329
+ # Q from x only
330
+ Q = self.q_proj(x) # [b, 300, 4096]
331
+ seq_len = Q.shape[1]
332
+
333
+ # K, V from concatenation of [x, y]
334
+ K = self.k_proj(x) # [b, 300, 4096]
335
+ # Reshape for multi-head attention
336
+ Q = Q.view(b, Q.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 300, 128]
337
+ K = K.view(b, K.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 600, 128]
338
+
339
+ """
340
+ V = self.v_proj(y) # [b, 300, 4096]
341
+ shared_V = self.v_proj(x) # [b, 300, 4096]
342
+
343
+ textual_concepts = V.view(b, V.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 300, 128]
344
+ shared_concepts = shared_V.view(b, shared_V.shape[1], self.num_heads, self.head_dim).transpose(1,
345
+ 2) # [b, 32, 300, 128]
346
+ expand_y_mask = y_mask.unsqueeze(1).unsqueeze(-1) # [b, 1, 300, 1]
347
+ # Compute attention scores
348
+ """
349
+ attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [b, 32, 300, 300]
350
+ attn_weights = F.softmax(attn_scores, dim=-1) # [b, 32, 300, 300]
351
+
352
+ # Compute output
353
+ attn_output = torch.matmul(attn_weights, textual_concepts) # [b, 32, 300, 128]
354
+ """
355
+
356
+ diagonal_influence = self.mask_mlp((textual_concepts))
357
+ if 'soft' in self.gate_mode:
358
+ diagonal_influence = 2 * (torch.sigmoid(diagonal_influence * temperature)) # [b, 32, 300, 1]
359
+ diagonal_influence = (diagonal_influence > 0.1).to(
360
+ diagonal_influence.dtype) * diagonal_influence # Thresholding
361
+ soft_influence = diagonal_influence
362
+ else:
363
+ soft_influence = torch.sigmoid(diagonal_influence * temperature)
364
+ if threshold is None:
365
+ threshold = 0.5
366
+ else:
367
+ print('Using custom threshold for influence gating:', threshold)
368
+ hard_influence = (soft_influence >= threshold)
369
+ diagonal_influence = hard_influence + soft_influence - soft_influence.detach() # Straight-through estimator
370
+
371
+ if topk is not None:
372
+ print(diagonal_influence.shape, ' <<< shape before topk ')
373
+ top_k_values, top_k_indices = torch.topk(diagonal_influence, topk, dim=1)
374
+ result = torch.zeros_like(diagonal_influence)
375
+ result.scatter_(1, top_k_indices, top_k_values)
376
+ diagonal_influence = result
377
+ print('Applied top-k sparsification on influence gates with k=', topk)
378
+
379
+ diagonal_output = textual_concepts * diagonal_influence + shared_concepts * (
380
+ 1 - diagonal_influence) # [b, 32, 300, 128]
381
+ da,db,dc,dd = diagonal_output.shape
382
+ rec_diagonal = self.rec_mlp(diagonal_output.transpose(1,2).flatten(2)[y_mask.bool()].to(x.dtype))
383
+ tgt_diagonal = y[y_mask.bool()]
384
+
385
+ diagonal_output = expand_y_mask * diagonal_output + (1 - expand_y_mask) * shared_concepts # [b, 32, 300, 128]
386
+
387
+ mask_bool_expanded = expand_y_mask.expand_as(diagonal_influence).bool() # [b, 32, 300, 1]
388
+ meaningful_gates = diagonal_influence[mask_bool_expanded]
389
+ soft_meaningful_gate = soft_influence[mask_bool_expanded]
390
+
391
+
392
+ # full_output = self.learnable_scale_norm*attn_output + diagonal_output # [b, 32, 300, 128]
393
+ full_output = diagonal_output.to(x.dtype)
394
+
395
+ # Reshape back
396
+ full_output = full_output.transpose(1, 2).contiguous().view(b, y.shape[1], d_model) # [b, 300, 4096]
397
+ full_output = full_output  # no-op; no residual connection is applied here
398
+
399
+ # Final output projection
400
+ output = self.out_proj(full_output) # [b, 300, 4096]
401
+
402
+ return output, diagonal_influence.squeeze(-1).transpose(1, 2), meaningful_gates, soft_meaningful_gate, rec_diagonal, tgt_diagonal
403
+
404
+
405
+
406
+
407
+
408
+
409
+
410
+ def init_weights_gaussian(model, mean=0.0, std=0.02):
411
+ """
412
+ Initialize all nn.Linear layers in the model:
413
+ - weights with Gaussian(mean, std)
414
+ - biases to 0
415
+ """
416
+ for m in model.modules():
417
+ if isinstance(m, nn.Linear):
418
+ nn.init.normal_(m.weight, mean=mean, std=std)
419
+ if m.bias is not None:
420
+ nn.init.constant_(m.bias, 0.0)
421
+
422
+ class ConceptAligner(nn.Module):
423
+ def __init__(self, per_dim=4):
424
+ super().__init__()
425
+ empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
426
+ self.register_buffer('empty_pooled_clip', empty_pooled_clip)
427
+
428
+ test_eps = torch.randn([1, 300, per_dim], dtype=torch.bfloat16).to('cpu')*0.7
429
+ self.register_buffer('test_eps', test_eps)
430
+
431
+ self.init_proj = nn.Sequential(nn.Linear(768, 300*16), nn.SiLU())
432
+ self.proj = nn.Sequential(nn.Linear(16, 1024), nn.SiLU(),
433
+ nn.Linear(1024, 1024), nn.SiLU())
434
+ self.text_proj = nn.Sequential(nn.Linear(4096, 1024), nn.SiLU(),
435
+ nn.Linear(1024, 1024), nn.SiLU())
436
+ self.proj_mu = nn.Sequential(nn.Linear(1024, per_dim))
437
+ self.proj_logvar = nn.Sequential(nn.Linear(1024, per_dim))
438
+
439
+ self.eps_proj = nn.Sequential(nn.Linear(per_dim, 1024), nn.SiLU(),
440
+ nn.LayerNorm(1024),
441
+ nn.Linear(1024, 4096))
442
+
443
+ init_weights_gaussian(self, mean=0.0, std=0.02)
444
+ torch.nn.init.constant_(self.eps_proj[-1].weight, 0.0)
445
+ torch.nn.init.constant_(self.eps_proj[-1].bias, 0.0)
446
+
447
+
448
+ @property
449
+ def dtype(self):
450
+ """Return the dtype of the model parameters."""
451
+ # return next(self.parameters()).dtype
452
+ return torch.bfloat16
453
+
454
+ @property
455
+ def device(self):
456
+ """Return the device of the model parameters."""
457
+ # return next(self.parameters()).device
458
+ return self.empty_pooled_clip.device
459
+
460
+ def forward(self, text_features, image_features=None, eps=None):
461
+
462
+ #return text_features, None, self.empty_pooled_clip.expand(text_features.shape[0], -1), torch.zeros(text_features.shape[1], 3).to(device=text_features.device, dtype=text_features.dtype), {'mu': torch.zeros([1,300,1], device=text_features.device, dtype=text_features.dtype), 'logvar': torch.zeros([1,300,1], device=text_features.device, dtype=text_features.dtype)}
463
+
464
+ dtype = text_features.dtype
465
+ device = text_features.device
466
+
467
+ if image_features is not None:
468
+ visual_hidden = self.proj(self.init_proj(image_features).view(len(image_features), 300, -1))
469
+ text_hidden = self.text_proj(text_features.detach())
470
+ hidden = visual_hidden - text_hidden
471
+ mu = self.proj_mu(hidden)
472
+ logvar = self.proj_logvar(hidden)
473
+ eps = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
474
+ else:
475
+ if eps is None:
476
+ eps = self.test_eps.to(device=device, dtype=dtype)
477
+ mu = torch.zeros_like(eps)
478
+ logvar = torch.zeros_like(eps)
479
+
480
+ proj_eps = self.eps_proj(eps)
481
+ prompt_embeds = text_features + proj_eps
482
+ pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)
483
+
484
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
485
+ aux_info = {
486
+ 'mu': mu,
487
+ 'logvar': logvar,
488
+ 'eps': eps
489
+ }
490
+
491
+ return prompt_embeds, None, pooled_prompt_embeds, text_ids, aux_info
492
+
493
+
494
+
495
+
496
+
497
+
498
+ if __name__ == '__main__':
499
+ from transformers import AutoProcessor
500
+ from diffusers import FluxPipeline
501
+ import os
502
+ from PIL import Image
503
+ def create_image_grid(images, cols):
504
+ rows = (len(images) + cols - 1) // cols
505
+ w, h = images[0].size
506
+ grid = Image.new('RGB', (cols * w, rows * h))
507
+ for i, img in enumerate(images):
508
+ grid.paste(img, (i % cols * w, i // cols * h))
509
+ return grid
510
+
511
+ dim = 4096
512
+ num_heads = 32
513
+ dtype = torch.bfloat16
514
+ model = ConceptAligner().to('cuda').to(dtype)
515
+ x = torch.randn([5, 300, dim]).to('cuda').to(dtype)
516
+ y = torch.randn([5, 300, dim]).to('cuda').to(dtype)
517
+ i = torch.randn([5,768]).to('cuda').to(dtype)
518
+ y[1] = y[0]
519
+ m = torch.ones([5, 300]).to('cuda').to(dtype)
520
+ m[:3,:128] = 0
521
+ prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(x, i)
522
+ print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
523
+ print(prompt_embeds.shape, ' ', pooled_prompt_embeds.shape, ' ', text_ids.shape)
524
+ for k in aux_info:
525
+ print(k, aux_info[k].shape, aux_info[k].min(), aux_info[k].max(), aux_info[k].mean())
526
+
527
+ from text_encoder import LoraT5Embedder
528
+ from datasets import load_dataset
529
+ dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
530
+ item = dataset[0:4]
531
+ another_item = dataset[0:4]
532
+ from diffusers.models.normalization import RMSNorm
533
+ clip_processor = AutoProcessor.from_pretrained("./clip-vit-large-patch14")
534
+ clip_images = clip_processor(images=item['image'], return_tensors="pt").pixel_values.to('cuda:0').to(dtype)
535
+ texts = []
536
+ texts.append("""A heartwarming 3D rendered scene of
537
+ an elderly farmer and a tiny orange
538
+ kitten. The farmer, with a gentle smile,
539
+ walks alongside the kitten in a lush,
540
+ green garden filled with thriving plants,
541
+ showcasing a fruitful harvest. The
542
+ intricate details of the overalls and the
543
+ farmer's worn, weathered face tell a
544
+ story of years spent tending to the land. the farmer is wearing a blue shirt""")
545
+ texts.append("""A unique, intricately detailed creature
546
+ resembling a reptile, possibly a lizard or
547
+ a gecko. It has a vibrant blue and green
548
+ scaled body, with large, round, and
549
+ expressive eyes that are a deep shade of
550
+ blue. The backdrop is a
551
+ soft, blurred forest setting, suggesting a
552
+ serene and mystical ambiance. the creature is wearing a golden crown""")
553
+ texts.append("""Deep in the enchanted forest lives a woman
554
+ who is the moon fairy. Her long blonde hair
555
+ shines in the starlight, tangled with her flowers
556
+ that glow with a soft blue glow. Her eyes are
557
+ the color of the night and shine with the magic
558
+ of the night. The fairy wears a dress made of
559
+ moon petals, woven with threads of moonlight
560
+ that shine with an iridescent glow, a crown of
561
+ stars adorns her head, shining with the light of
562
+ the full moon that illuminates the forest. Her
563
+ wings are translucent like glass, with a pale
564
+ glow reminiscent of the glow of the moon. HD,
565
+ 6K, photo, cinematic, poster""")
566
+
567
+ texts.append(
568
+ """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
569
+
570
+ text_encoder = LoraT5Embedder(device='cuda').to(dtype)
571
+ text_features, _, _, _, image_features, _ = text_encoder(texts, clip_images)
572
+ print(text_features.shape, image_features.shape, ' >>>>>>>>> text input')
573
+ images = []
574
+ pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
575
+ pipe.to('cuda')
576
+
577
+ for txt_feat, img_feat in zip(text_features, image_features):
578
+
579
+ prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(txt_feat.unsqueeze(0), img_feat.unsqueeze(0))
580
+ image = pipe(
581
+ prompt_embeds=prompt_embeds,
582
+ pooled_prompt_embeds=pooled_prompt_embeds,
583
+ height=512,
584
+ width=512,
585
+ guidance_scale=3.5,
586
+ num_inference_steps=20,
587
+ max_sequence_length=512,
588
+ generator=torch.Generator("cuda").manual_seed(1995),
589
+ ).images[0]
590
+ images.append(image)
591
+
592
+ aligned_image = create_image_grid(images, cols=len(images) // 2)
593
+ os.makedirs('samples', exist_ok=True)
594
+ aligned_image.save("samples/image%.jpg")
595
+
596
+
597
+ raise SystemExit
598
+
599
+ influence_matrix = aux_info['influence']
600
+ bin_influence_matrix = (influence_matrix > 0.1).float()
601
+ mean_alive = bin_influence_matrix.sum(dim=-1).mean()
602
+ max_alive = bin_influence_matrix.sum(dim=-1).max()
603
+ min_alive = bin_influence_matrix.sum(dim=-1).min()
604
+ max_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).max()
605
+ mean_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).mean()
606
+ min_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).min()
607
+
608
+ print(
609
+ f"Mean alive heads per token: {mean_alive:.2f}, Max alive heads per token: {max_alive:.2f}, Min alive heads per token: {min_alive:.2f}")
610
+ print(
611
+ f"Mean alive tokens: {mean_token_alive:.2f}, Max alive tokens: {max_token_alive:.2f}, Min alive tokens: {min_token_alive:.2f}")
612
+
613
+ import os
614
+
615
+ CHECKPOINT_PATH = 'runs/00393/checkpoint-6000'
616
+ from safetensors.torch import load_file
617
+
618
+ # Load adapter (model.safetensors)
619
+ adapter_path = os.path.join(CHECKPOINT_PATH, "model_1.safetensors")
620
+ if os.path.exists(adapter_path):
621
+ adapter_state = load_file(adapter_path)
622
+ model.load_state_dict(adapter_state, strict=True)
623
+ print("Adapter loaded successfully!")
624
+
625
+ print(model.influence_net.v_proj.weight, ' <<< weight ')
626
+ print(model.influence_net.v_proj.bias, ' <<< bias ')
627
+ print(model.influence_net.out_proj.weight, ' <<< out weight ')
628
+ print(model.influence_net.out_proj.bias, ' <<< out bias ')
629
+ print(model.influence_net.mask_mlp.linear3.weight, ' <<< gate weight 3 ')
630
+ print(model.influence_net.mask_mlp.linear3.bias, ' <<< gate bias ')
631
+
632
+ z = torch.randn([3, num_heads, 300, 4096 // num_heads]).to('cuda').to(dtype)
633
+ gate_values = model.influence_net.mask_mlp(z)
634
+ gate_values = 2 * (torch.sigmoid(gate_values))
635
+
636
+ print(gate_values, ' <<< gate values ', gate_values.shape, ' ', torch.mean(gate_values))
637
+
638
+ from diffusers import FluxPipeline
639
+ from PIL import Image
640
+
641
+
642
+
643
+
644
+ reserved_memory = torch.cuda.memory_reserved(0) / (1024 ** 3)
645
+ print(f"Reserved GPU memory: {reserved_memory:.2f} GB")
646
+
647
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel
648
+ import torch
649
+ from text_encoder import LoraT5Embedder
650
+
651
+
652
+ text_encoder = LoraT5Embedder(device='cuda').to(torch.bfloat16)
653
+ texts = []
654
+ texts.append("""A heartwarming 3D rendered scene of
655
+ an elderly farmer and a tiny orange
656
+ kitten. The farmer, with a gentle smile,
657
+ walks alongside the kitten in a lush,
658
+ green garden filled with thriving plants,
659
+ showcasing a fruitful harvest. The
660
+ intricate details of the overalls and the
661
+ farmer's worn, weathered face tell a
662
+ story of years spent tending to the land. the farmer is wearing a blue shirt""")
663
+ texts.append("""A unique, intricately detailed creature
664
+ resembling a reptile, possibly a lizard or
665
+ a gecko. It has a vibrant blue and green
666
+ scaled body, with large, round, and
667
+ expressive eyes that are a deep shade of
668
+ blue. The backdrop is a
669
+ soft, blurred forest setting, suggesting a
670
+ serene and mystical ambiance. the creature is wearing a golden crown""")
671
+ texts.append("""Deep in the enchanted forest lives a woman
672
+ who is the moon fairy. Her long blonde hair
673
+ shines in the starlight, tangled with her flowers
674
+ that glow with a soft blue glow. Her eyes are
675
+ the color of the night and shine with the magic
676
+ of the night. The fairy wears a dress made of
677
+ moon petals, woven with threads of moonlight
678
+ that shine with an iridescent glow, a crown of
679
+ stars adorns her head, shining with the light of
680
+ the full moon that illuminates the forest. Her
681
+ wings are translucent like glass, with a pale
682
+ glow reminiscent of the glow of the moon. HD,
683
+ 6K, photo, cinematic, poster""")
684
+
685
+ texts.append(
686
+ """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
687
+ texts.append(
688
+ """In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is bare, revealing the natural curve of its muscles and the sheen of its coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and quiet power.""")
689
+ INDEX = 0
690
+ text = texts[INDEX]
691
+ with torch.no_grad():
692
+ floral_embeds, _,_,_,_,attn_mask = text_encoder(text, )
693
+ print(attn_mask.shape, ' >>>> ', attn_mask)
694
+ print(floral_embeds.shape, shared_embeds.shape, ' >>>> floral ')
695
+ nopad_floral_embeds, nopad_shared_embeds, nopad_attn_mask = text_encoder(text, padding=False)
696
+ print(floral_embeds.shape, shared_embeds.shape, ' >>>> floral ')
697
+
698
+ """
699
+ _,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False)
700
+ print(aux_info['meaningful_influence'].shape, ' <<< influence shape ', aux_info['meaningful_influence'][:100],' ',torch.mean(aux_info['meaningful_influence']))
701
+ floral_embeds, shared_embeds, attn_mask = text_encoder([""], padding='max_length')
702
+ _,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False)
703
+ print(aux_info['meaningful_influence'].shape, ' <<< empty influence shape ', aux_info['meaningful_influence'],' ',torch.mean(aux_info['meaningful_influence']))
704
+ raise SystemExit
705
+ """
706
+
707
+ text2s = []
708
+ text2s.append("""A heartwarming 3D rendered scene of
709
+ an elderly farmer and a tiny orange
710
+ kitten. The farmer, with a gentle smile,
711
+ walks alongside the kitten in a lush,
712
+ green garden filled with thriving plants,
713
+ showcasing a fruitful harvest. The
714
+ intricate details of the overalls and the
715
+ farmer's worn, weathered face tell a
716
+ story of years spent tending to the land. the farmer is wearing a red shirt""")
717
+ text2s.append("""A unique, intricately detailed creature
718
+ resembling a reptile, possibly a lizard or
719
+ a gecko. It has a vibrant blue and green
720
+ scaled body, with large, round, and
721
+ expressive eyes that are a deep shade of
722
+ blue. The backdrop is a
723
+ soft, blurred forest setting, suggesting a
724
+ serene and mystical ambiance. the creature is wearing a floral crown""")
725
+ text2s.append("""Deep in the enchanted forest lives a woman
726
+ who is the moon fairy. Her long black hair
727
+ shines in the starlight, tangled with her flowers
728
+ that glow with a soft blue glow. Her eyes are
729
+ the color of the night and shine with the magic
730
+ of the night. The fairy wears a dress made of
731
+ moon petals, woven with threads of moonlight
732
+ that shine with an iridescent glow, a crown of
733
+ stars adorns her head, shining with the light of
734
+ the full moon that illuminates the forest. Her
735
+ wings are translucent like glass, with a pale
736
+ glow reminiscent of the glow of the moon. HD,
737
+ 6K, photo, cinematic, poster""")
738
+ text2s.append(
739
+ """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a gray, rainy sky outside, with raindrops streaking down the glass and blurred trees beyond. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and introspective, capturing a cozy moment of quiet observation during a rainy afternoon.""")
740
+ text2s.append(
741
+ """In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is adorned with a bright red saddle, contrasting sharply against its white coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and a striking touch of color that adds visual drama.""")
742
+ text2 = text2s[INDEX]
743
+
744
+ with torch.no_grad():
745
+ golden_embeds, shared_embeds, golden_mask = text_encoder(text2, padding='max_length')
746
+ print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')
747
+ nopad_golden_embeds, nopad_shared_embeds, nopad_golden_mask = text_encoder(text2, padding=False)
748
+ print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')
749
+
750
+ batch_encoding = text_encoder.t5_tokenizer(
751
+ text,
752
+ truncation=True,
753
+ max_length=text_encoder.max_length,
754
+ return_tensors="pt",
755
+ )
756
+
757
+ input_ids = batch_encoding["input_ids"][0] # Get the token IDs
758
+
759
+ # Convert token IDs back to tokens to see what they are
760
+ tokens_floral = text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)
761
+
762
+ batch_encoding = text_encoder.t5_tokenizer(
763
+ text2,
764
+ truncation=True,
765
+ max_length=text_encoder.max_length,
766
+ return_tensors="pt",
767
+ )
768
+
769
+ input_ids = batch_encoding["input_ids"][0] # Get the token IDs
770
+ tokens_golden = text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)
771
+
772
+
773
+ # Convert token IDs back to tokens to see what they are
774
+
775
+ # Find the index of specific words
776
+ def find_token_indices(tokens, word):
777
+ """Find all indices where a word or its token appears"""
778
+ indices = []
779
+ # T5 tokenizer might split words or add special characters
780
+ word_token = text_encoder.t5_tokenizer.encode(word, add_special_tokens=False)[0]
781
+ word_token_str = text_encoder.t5_tokenizer.convert_ids_to_tokens([word_token])[0]
782
+
783
+ for i, token in enumerate(tokens):
784
+ if token == word_token_str or word.lower() in token.lower():
785
+ indices.append(i)
786
+ return indices
787
+
788
+
789
+ key1s = ['blue', 'golden', 'blonde', 'clear', 'horse']
790
+ key2s = ['red', 'floral', 'black', 'rainy', 'red']
791
+
792
+ # Find indices for "blue"
793
+ blue_indices = find_token_indices(tokens_floral, key1s[INDEX])[-1]
794
+ print(f"Indices for 'blue': {blue_indices}")
795
+
796
+ # Find indices for "red" (won't be found in this text)
797
+ red_indices = find_token_indices(tokens_golden, key2s[INDEX])[-1]
798
+ print(f"Indices for 'red': {red_indices}")
799
+
800
+ pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
801
+ pipe.to('cuda')
802
+ adapter_path = os.path.join(CHECKPOINT_PATH, "model.safetensors")
803
+ if os.path.exists(adapter_path):
804
+ adapter_state = load_file(adapter_path)
805
+ pipe.transformer.load_state_dict(adapter_state, strict=True)
806
+ print("Transformer loaded successfully!")
807
+
808
+ images = []
809
+ empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu').to('cuda').to(torch.bfloat16)
810
+
811
+ print("Generating image with concatenation...")
812
+ images = []
813
+ # for cur_prompt_embed in [floral_embeds, nopad_floral_embeds
814
+ # , inter_embed, golden_embeds, nopad_golden_embeds]:
815
+
816
+ # for (start_dim, end_dim) in [(0,4096), (1024,4096), (2048, 4096), (1024, 2048)]:
817
+
818
+
819
+ for emb in ['floral', 'golden']:
820
+ for temp in [2.5]:
821
+ for thr in [-1, 0.5, 0.75, 0.85, 0.95]:
822
+ for topk in [None]:
823
+ print('>>>> Temperature: ', temp, topk)
824
+ if 'floral' in emb:
825
+ inter_embed, _, _, _, new_aux_info = model(floral_embeds, shared_embeds, attn_mask,
826
+ is_training=False, temperature=temp,
827
+ threshold=thr, topk=topk)
828
+ else:
829
+ inter_embed, _, _, _, new_aux_info = model(golden_embeds, shared_embeds, golden_mask,
830
+ is_training=False, temperature=temp,
831
+ threshold=thr, topk=topk)
832
+
833
+ print(new_aux_info['influence'][:, blue_indices].shape, ' >>>> influence ',
834
+ new_aux_info['influence'][:, blue_indices])
835
+ print(new_aux_info['meaningful_influence'], ' >>>> meaningful influence ',
836
+ torch.mean(new_aux_info['meaningful_influence']))
837
+
838
+ # inter_embed = torch.clone(floral_embeds)
839
+ # inter_embed[:, blue_indices] = shared_embeds[:, blue_indices]
840
+ # inter_embed[:, blue_indices, start_dim:end_dim] = floral_embeds[:, blue_indices, start_dim:end_dim]
841
+
842
+ image = pipe(
843
+ prompt_embeds=inter_embed,
844
+ pooled_prompt_embeds=empty_pooled_clip,
845
+ height=512,
846
+ width=512,
847
+ guidance_scale=3.5,
848
+ num_inference_steps=20,
849
+ max_sequence_length=512,
850
+ generator=torch.Generator("cuda").manual_seed(1995),
851
+ ).images[0]
852
+ images.append(image)
853
+ aligned_image = create_image_grid(images, cols=len(images) // 2)
854
+ os.makedirs('samples', exist_ok=True)
855
+ aligned_image.save("samples/image%s.jpg" % INDEX)
856
+
857
+
858
+
859
+
860
+
861
+
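As a quick sanity check of aligner.py above, the following minimal sketch exercises the masked `MultiHeadSelfAttention` on random inputs (it assumes aligner.py is importable, i.e. the `refiner` module imported at its top is on the path; shapes follow the class docstring):

import torch
from aligner import MultiHeadSelfAttention

attn = MultiHeadSelfAttention(embed_dim=2560, num_heads=20)
x = torch.randn(2, 16, 2560)   # [batch, seq_len, embed_dim]
mask = torch.ones(2, 16)       # 1 = attend, 0 = ignore
mask[:, 8:] = 0                # treat the second half of each sequence as padding
out, weights = attn(x, mask=mask, return_attention=True)
print(out.shape)      # torch.Size([2, 16, 2560])
print(weights.shape)  # torch.Size([40, 16, 16]), i.e. [batch * num_heads, seq_len, seq_len]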
app.py ADDED
@@ -0,0 +1,210 @@
1
+ import torch
2
+ import os
3
+ from huggingface_hub import hf_hub_download
4
+ from safetensors.torch import load_file
5
+ from aligner import ConceptAligner
6
+ from text_encoder import LoraT5Embedder
7
+ from pipeline import CustomFluxKontextPipeline
8
+ from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
9
+ from peft import LoraConfig
10
+ import gradio as gr
11
+
12
+ # Configuration
13
+ MODEL_REPO = "Shaoan/ConceptAligner-Weights" # Your model repo
14
+ CHECKPOINT_DIR = "./checkpoint"
15
+
16
+ def download_checkpoint():
17
+ """Download checkpoint files from HF model repo"""
18
+ print("Downloading checkpoint files...")
19
+
20
+ files = [
21
+ "model.safetensors",
22
+ "model_1.safetensors",
23
+ "model_2.safetensors"
24
+ ]
25
+
26
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
27
+
28
+ for filename in files:
29
+ local_path = os.path.join(CHECKPOINT_DIR, filename)
30
+ if not os.path.exists(local_path):
31
+ print(f" Downloading {filename}...")
32
+ hf_hub_download(
33
+ repo_id=MODEL_REPO,
34
+ filename=filename,
35
+ local_dir=CHECKPOINT_DIR,
36
+ local_dir_use_symlinks=False
37
+ )
38
+ print(f" ✓ {filename} downloaded")
39
+
40
+ print("✓ All checkpoint files ready!")
41
+
42
+ class ConceptAlignerModel:
43
+ def __init__(self):
44
+ # Download checkpoint first
45
+ download_checkpoint()
46
+
47
+ self.checkpoint_path = CHECKPOINT_DIR
48
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
49
+ self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
50
+
51
+ self.previous_image = None
52
+ self.previous_prompt = None
53
+
54
+ print(f"\n{'='*60}")
55
+ print(f"Loading ConceptAligner Model")
56
+ print(f"Device: {self.device}")
57
+ print(f"{'='*60}")
58
+
59
+ self.setup_models()
60
+
61
+ def setup_models(self):
62
+ """Load all models"""
63
+ # Load ConceptAligner
64
+ print(f" Loading ConceptAligner...")
65
+ self.model = ConceptAligner().to(self.device).to(self.dtype)
66
+ adapter_path = os.path.join(self.checkpoint_path, "model_1.safetensors")
67
+ adapter_state = load_file(adapter_path)
68
+ self.model.load_state_dict(adapter_state, strict=True)
69
+ print(f" ✓ Adapter loaded")
70
+
71
+ # Load T5 encoder
72
+ print(f" Loading T5 encoder...")
73
+ self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
74
+ adapter_path = os.path.join(self.checkpoint_path, "model_2.safetensors")
75
+ adapter_state = load_file(adapter_path)
76
+ if "t5_encoder.shared.weight" in adapter_state and "t5_encoder.encoder.embed_tokens.weight" not in adapter_state:
77
+ adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
78
+ self.text_encoder.load_state_dict(adapter_state, strict=True)
79
+ print(f" ✓ T5 Adapter loaded")
80
+
81
+ # Load VAE
82
+ print(f" Loading VAE...")
83
+ vae = AutoencoderKL.from_pretrained(
84
+ 'black-forest-labs/FLUX.1-dev',
85
+ subfolder="vae",
86
+ torch_dtype=self.dtype
87
+ ).to(self.device)
88
+
89
+ # Load transformer
90
+ print(f" Loading transformer...")
91
+ transformer = FluxTransformer2DModel.from_pretrained(
92
+ 'black-forest-labs/FLUX.1-dev',
93
+ subfolder="transformer",
94
+ torch_dtype=self.dtype
95
+ )
96
+
97
+ target_modules = [
98
+ "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
99
+ "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
100
+ "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
101
+ "proj_mlp", "proj_out", "norm.linear", "norm1.linear"
102
+ ]
103
+
104
+ transformer_lora_config = LoraConfig(
105
+ r=256,
106
+ lora_alpha=256,
107
+ lora_dropout=0.0,
108
+ init_lora_weights="gaussian",
109
+ target_modules=target_modules,
110
+ )
111
+ transformer.add_adapter(transformer_lora_config)
112
+ transformer.context_embedder.requires_grad_(True)
113
+
114
+ # Load fine-tuned transformer
115
+ transformer_path = os.path.join(self.checkpoint_path, "model.safetensors")
116
+ transformer_state = load_file(transformer_path)
117
+ transformer.load_state_dict(transformer_state, strict=True)
118
+ print(f" ✓ Fine-tuned transformer loaded")
119
+
120
+ transformer = transformer.to(self.device)
121
+
122
+ # Load or download empty pooled clip
123
+ empty_clip_path = "empty_pooled_clip.pt"
124
+ if not os.path.exists(empty_clip_path):
125
+ print(" Downloading empty_pooled_clip.pt...")
126
+ hf_hub_download(
127
+ repo_id=MODEL_REPO,
128
+ filename="empty_pooled_clip.pt",
129
+ local_dir=".",
130
+ local_dir_use_symlinks=False
131
+ )
132
+
133
+ self.empty_pooled_clip = torch.load(empty_clip_path, map_location=self.device).to(self.dtype)
134
+
135
+ # Create pipeline
136
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
137
+ 'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
138
+ )
139
+
140
+ self.pipe = CustomFluxKontextPipeline(
141
+ scheduler=noise_scheduler,
142
+ aligner=self.model.to(self.device).to(self.dtype),
143
+ transformer=transformer.to(self.device).to(self.dtype),
144
+ vae=vae.to(self.device).to(self.dtype),
145
+ text_embedder=self.text_encoder.to(self.device).to(self.dtype),
146
+ ).to(self.device)
147
+
148
+ if torch.cuda.is_available():
149
+ allocated = torch.cuda.memory_allocated(0) / 1024**3
150
+ reserved = torch.cuda.memory_reserved(0) / 1024**3
151
+ print(f" ✓ Pipeline ready on {self.device}")
152
+ print(f" 📊 GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
153
+ else:
154
+ print(f" ✓ Pipeline ready on {self.device}")
155
+
156
+ @torch.no_grad()
157
+ def generate_image(
158
+ self,
159
+ prompt,
160
+ threshold=0.0,
161
+ topk=0,
162
+ height=512,
163
+ width=512,
164
+ guidance_scale=3.5,
165
+ true_cf_scale=1.0,
166
+ num_inference_steps=20,
167
+ seed=1995
168
+ ):
169
+ """Generate image and return previous + current for comparison"""
170
+ if not prompt.strip():
171
+ return self.previous_image, None, self.previous_prompt or ""
172
+
173
+ try:
174
+ generator = torch.Generator(device=self.device).manual_seed(int(seed))
175
+
176
+ current_image = self.pipe(
177
+ prompt=prompt,
178
+ guidance_scale=guidance_scale,
179
+ true_cfg_scale=true_cf_scale,
180
+ max_sequence_length=512,
181
+ num_inference_steps=num_inference_steps,
182
+ height=height,
183
+ width=width,
184
+ generator=generator,
185
+ ).images[0]
186
+
187
+ prev_image = self.previous_image
188
+ prev_prompt = self.previous_prompt or "No previous generation"
189
+
190
+ self.previous_image = current_image
191
+ self.previous_prompt = prompt
192
+
193
+ return prev_image, current_image, prev_prompt
194
+
195
+ except Exception as e:
196
+ import traceback
197
+ error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
198
+ print(error_msg)
199
+ return self.previous_image, None, self.previous_prompt or ""
200
+
201
+ def reset_history(self):
202
+ """Clear generation history"""
203
+ self.previous_image = None
204
+ self.previous_prompt = None
205
+ return None, None, "No previous generation"
206
+
207
+
208
+ # Initialize model
209
+ print("Initializing ConceptAligner model...")
210
+ model = ConceptAlignerModel()
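A minimal sketch of driving the `ConceptAlignerModel` instantiated above once loading finishes (assumes the checkpoint downloads succeeded and a CUDA device is available; the prompt is arbitrary):

prev_img, cur_img, prev_prompt = model.generate_image(
    prompt="A fluffy white cat sitting on a sunlit windowsill",
    height=512,
    width=512,
    guidance_scale=3.5,
    num_inference_steps=20,
    seed=1995,
)
if cur_img is not None:
    cur_img.save("sample.png")  # current generation; prev_img holds the previous one, if any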
empty_pooled_clip.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92acbe688a00c835deb9b645fe673e16af2ceef9cd749a8b838e67dea23d76b2
3
+ size 3183
pipeline.py ADDED
@@ -0,0 +1,641 @@
1
+
2
+ import inspect
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+ import numpy as np
5
+ import torch
6
+ from transformers import (
7
+ CLIPImageProcessor,
8
+ CLIPTextModel,
9
+ CLIPTokenizer,
10
+ CLIPVisionModelWithProjection,
11
+ T5EncoderModel,
12
+ T5TokenizerFast,
13
+ )
14
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
15
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
16
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
17
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
18
+ from diffusers.utils import (
19
+ USE_PEFT_BACKEND,
20
+ deprecate,
21
+ is_torch_xla_available,
22
+ logging,
23
+ replace_example_docstring,
24
+ scale_lora_layers,
25
+ unscale_lora_layers,
26
+ )
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
29
+ from diffusers import FluxKontextPipeline
30
+
31
+ PREFERRED_KONTEXT_RESOLUTIONS = [
32
+ (672, 1568),
33
+ (688, 1504),
34
+ (720, 1456),
35
+ (752, 1392),
36
+ (800, 1328),
37
+ (832, 1248),
38
+ (880, 1184),
39
+ (944, 1104),
40
+ (1024, 1024),
41
+ (1104, 944),
42
+ (1184, 880),
43
+ (1248, 832),
44
+ (1328, 800),
45
+ (1392, 752),
46
+ (1456, 720),
47
+ (1504, 688),
48
+ (1568, 672),
49
+ ]
50
+
51
+
52
+ def calculate_shift(
53
+ image_seq_len,
54
+ base_seq_len: int = 256,
55
+ max_seq_len: int = 4096,
56
+ base_shift: float = 0.5,
57
+ max_shift: float = 1.15,
58
+ ):
59
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
60
+ b = base_shift - m * base_seq_len
61
+ mu = image_seq_len * m + b
62
+ return mu
63
+
64
+
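For orientation, `calculate_shift` above is a linear interpolation of the scheduler shift between `base_seq_len` and `max_seq_len`; a worked example with the defaults:

# mu = base_shift + (max_shift - base_shift) * (image_seq_len - base_seq_len) / (max_seq_len - base_seq_len)
# image_seq_len = 1024  ->  mu = 0.5 + 0.65 * (1024 - 256) / (4096 - 256) = 0.63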
65
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
66
+ def retrieve_timesteps(
67
+ scheduler,
68
+ num_inference_steps: Optional[int] = None,
69
+ device: Optional[Union[str, torch.device]] = None,
70
+ timesteps: Optional[List[int]] = None,
71
+ sigmas: Optional[List[float]] = None,
72
+ **kwargs,
73
+ ):
74
+ r"""
75
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
76
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
77
+
78
+ Args:
79
+ scheduler (`SchedulerMixin`):
80
+ The scheduler to get timesteps from.
81
+ num_inference_steps (`int`):
82
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
83
+ must be `None`.
84
+ device (`str` or `torch.device`, *optional*):
85
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
86
+ timesteps (`List[int]`, *optional*):
87
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
88
+ `num_inference_steps` and `sigmas` must be `None`.
89
+ sigmas (`List[float]`, *optional*):
90
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
91
+ `num_inference_steps` and `timesteps` must be `None`.
92
+
93
+ Returns:
94
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
95
+ second element is the number of inference steps.
96
+ """
97
+ if timesteps is not None and sigmas is not None:
98
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
99
+ if timesteps is not None:
100
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
101
+ if not accepts_timesteps:
102
+ raise ValueError(
103
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
104
+ f" timestep schedules. Please check whether you are using the correct scheduler."
105
+ )
106
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
107
+ timesteps = scheduler.timesteps
108
+ num_inference_steps = len(timesteps)
109
+ elif sigmas is not None:
110
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
111
+ if not accept_sigmas:
112
+ raise ValueError(
113
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
114
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
115
+ )
116
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
117
+ timesteps = scheduler.timesteps
118
+ num_inference_steps = len(timesteps)
119
+ else:
120
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
121
+ timesteps = scheduler.timesteps
122
+ return timesteps, num_inference_steps
123
+
124
+
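A small usage sketch of `retrieve_timesteps` with a default-configured flow-match scheduler (the FLUX.1-dev scheduler config typically enables dynamic shifting and would then also need `mu`, so a plain default instance is used here purely for illustration):

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=20, device="cpu")
print(n_steps, timesteps.shape)  # 20, torch.Size([20])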
125
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
126
+ def retrieve_latents(
127
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
128
+ ):
129
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
130
+ return encoder_output.latent_dist.sample(generator)
131
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
132
+ return encoder_output.latent_dist.mode()
133
+ elif hasattr(encoder_output, "latents"):
134
+ return encoder_output.latents
135
+ else:
136
+ raise AttributeError("Could not access latents of provided encoder_output")
137
+
138
+
139
+ from diffusers import FluxKontextPipeline
140
+ from typing import Union, List, Optional
141
+ import torch
142
+
143
+
144
+ class CustomFluxKontextPipeline(FluxKontextPipeline):
145
+ r"""
146
+ Custom Flux Kontext pipeline with a wrapper text embedder.
147
+ """
148
+
149
+ model_cpu_offload_seq = "text_embedder->image_encoder->transformer->vae"
150
+
151
+ def __init__(
152
+ self,
153
+ scheduler,
154
+ vae,
155
+ text_embedder, # Your custom text embedder wrapper
156
+ transformer,
157
+ aligner,
158
+ image_encoder=None,
159
+ feature_extractor=None,
160
+ ):
161
+ # Don't call super().__init__() since parent expects text_encoder parameters
162
+ # Instead, manually register modules
163
+ from diffusers import DiffusionPipeline
164
+ DiffusionPipeline.__init__(self)
165
+
166
+ self.register_modules(
167
+ vae=vae,
168
+ text_embedder=text_embedder,
169
+ transformer=transformer,
170
+ scheduler=scheduler,
171
+ aligner=aligner,
172
+ image_encoder=image_encoder,
173
+ feature_extractor=feature_extractor,
174
+ )
175
+
176
+ # Initialize the necessary attributes from parent
177
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
178
+ self.latent_channels = self.vae.config.latent_channels
179
+ from diffusers.image_processor import VaeImageProcessor
180
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
181
+ self.default_sample_size = 128
182
+
183
+ def encode_prompt(
184
+ self,
185
+ prompt: Union[str, List[str]],
186
+ prompt_2: Optional[Union[str, List[str]]] = None,
187
+ device: Optional[torch.device] = None,
188
+ num_images_per_prompt: int = 1,
189
+ prompt_embeds: Optional[torch.FloatTensor] = None,
190
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
191
+ max_sequence_length: int = 512,
192
+ lora_scale: Optional[float] = None,
193
+ temperature=None,
194
+ threshold=None,
195
+ ):
196
+ device = device or self._execution_device
197
+
198
+ if prompt_embeds is None:
199
+ # Use your custom text embedder
200
+ qwen_embeds, clip_image_embeds, perturbed_qwen_embeds, replace_ids, t5_tokenizer, batch_encoding = self.text_embedder(prompt)
201
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds, text_ids, _ = self.aligner(qwen_embeds)
203
+ prompt_embeds = prompt_embeds.to(device=device)
204
+ pooled_prompt_embeds = pooled_prompt_embeds.to(device=device)
205
+ text_ids = text_ids.to(device=device)
206
+ else:
207
+ # When embeddings are provided, create text_ids and fall back to a null attention mask
+ dtype = self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ prompt_attention_mask = None
210
+
211
+ # Duplicate for num_images_per_prompt
212
+ if num_images_per_prompt > 1:
213
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
214
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
215
+ text_ids = text_ids.repeat(num_images_per_prompt, 1)
216
+
217
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds, text_ids
218
+
219
+ @torch.no_grad()
220
+ def __call__(
221
+ self,
222
+ image: Optional[PipelineImageInput] = None,
223
+ prompt: Union[str, List[str]] = None,
224
+ prompt_2: Optional[Union[str, List[str]]] = None,
225
+ negative_prompt: Union[str, List[str]] = "",
226
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
227
+ true_cfg_scale: float = 1.0,
228
+ height: Optional[int] = None,
229
+ width: Optional[int] = None,
230
+ num_inference_steps: int = 28,
231
+ sigmas: Optional[List[float]] = None,
232
+ guidance_scale: float = 3.5,
233
+ num_images_per_prompt: Optional[int] = 1,
234
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
235
+ latents: Optional[torch.FloatTensor] = None,
236
+ prompt_embeds: Optional[torch.FloatTensor] = None,
237
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
238
+ ip_adapter_image: Optional[PipelineImageInput] = None,
239
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
240
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
241
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
242
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
243
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
244
+ output_type: Optional[str] = "pil",
245
+ return_dict: bool = True,
246
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
247
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
248
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
249
+ max_sequence_length: int = 512,
250
+ max_area: int = 1024 ** 2,
251
+ _auto_resize: bool = True,
252
+ temperature=None,
253
+ threshold=None,
254
+ ):
255
+ r"""
256
+ Function invoked when calling the pipeline for generation.
257
+
258
+ Args:
259
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
260
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
261
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
265
+ prompt (`str` or `List[str]`, *optional*):
266
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
268
+ prompt_2 (`str` or `List[str]`, *optional*):
269
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead.
271
+ negative_prompt (`str` or `List[str]`, *optional*):
272
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
273
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
274
+ not greater than `1`).
275
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
276
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
277
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
278
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
279
+ When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
280
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
281
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
282
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
283
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
284
+ num_inference_steps (`int`, *optional*, defaults to 28):
285
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
286
+ expense of slower inference.
287
+ sigmas (`List[float]`, *optional*):
288
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
289
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
290
+ will be used.
291
+ guidance_scale (`float`, *optional*, defaults to 3.5):
292
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
293
+ a model to generate images more closely aligned with the prompt, at the expense of lower image quality.
294
+
295
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
296
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
297
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
298
+ The number of images to generate per prompt.
299
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
300
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
301
+ to make generation deterministic.
302
+ latents (`torch.FloatTensor`, *optional*):
303
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
304
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
305
+ tensor will be generated by sampling using the supplied random `generator`.
306
+ prompt_embeds (`torch.FloatTensor`, *optional*):
307
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
308
+ provided, text embeddings will be generated from `prompt` input argument.
309
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
310
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
311
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
312
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
313
+ Optional image input to work with IP Adapters.
314
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
315
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
316
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
317
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
318
+ negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
320
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
321
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
322
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
323
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
324
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
325
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
326
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
327
+ argument.
328
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
329
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
330
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
331
+ input argument.
332
+ output_type (`str`, *optional*, defaults to `"pil"`):
333
+ The output format of the generated image. Choose between
334
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
335
+ return_dict (`bool`, *optional*, defaults to `True`):
336
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
337
+ joint_attention_kwargs (`dict`, *optional*):
338
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
339
+ `self.processor` in
340
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
341
+ callback_on_step_end (`Callable`, *optional*):
342
+ A function that is called at the end of each denoising step during inference. The function is called
343
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
344
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
345
+ `callback_on_step_end_tensor_inputs`.
346
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
347
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
348
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
349
+ `._callback_tensor_inputs` attribute of your pipeline class.
350
+ max_sequence_length (`int` defaults to 512):
351
+ Maximum sequence length to use with the `prompt`.
352
+ max_area (`int`, defaults to `1024 ** 2`):
353
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
354
+ area while maintaining the aspect ratio.
355
+
356
+ Examples:
357
+
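+ A minimal usage sketch (illustrative only; it assumes the custom `text_embedder` and `aligner`
+ wrappers from this repo and the standard Flux components have already been instantiated):
+
+ ```py
+ >>> pipe = CustomFluxKontextPipeline(
+ ...     scheduler=scheduler,
+ ...     vae=vae,
+ ...     text_embedder=text_embedder,
+ ...     transformer=transformer,
+ ...     aligner=aligner,
+ ... )
+ >>> # For image editing, an input PIL image can additionally be passed via the `image` argument.
+ >>> image = pipe(prompt="A cat sleeping on a sunny windowsill", num_inference_steps=28).images[0]
+ >>> image.save("output.png")
+ ```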
358
+ Returns:
359
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
360
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
361
+ images.
362
+ """
363
+
364
+ height = height or self.default_sample_size * self.vae_scale_factor
365
+ width = width or self.default_sample_size * self.vae_scale_factor
366
+
367
+ original_height, original_width = height, width
368
+ aspect_ratio = width / height
369
+
370
+ """
371
+ width = round((max_area * aspect_ratio) ** 0.5)
372
+ height = round((max_area / aspect_ratio) ** 0.5)
373
+ multiple_of = self.vae_scale_factor * 2
374
+ width = width // multiple_of * multiple_of
375
+ height = height // multiple_of * multiple_of
376
+
377
+ if height != original_height or width != original_width:
378
+ print(
379
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
380
+ )
381
+ """
382
+
383
+ # 1. Check inputs. Raise error if not correct
384
+ self.check_inputs(
385
+ prompt,
386
+ prompt_2,
387
+ height,
388
+ width,
389
+ negative_prompt=negative_prompt,
390
+ negative_prompt_2=negative_prompt_2,
391
+ prompt_embeds=prompt_embeds,
392
+ negative_prompt_embeds=negative_prompt_embeds,
393
+ pooled_prompt_embeds=pooled_prompt_embeds,
394
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
395
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
396
+ max_sequence_length=max_sequence_length,
397
+ )
398
+
399
+ self._guidance_scale = guidance_scale
400
+ self._joint_attention_kwargs = joint_attention_kwargs
401
+ self._current_timestep = None
402
+ self._interrupt = False
403
+
404
+ # 2. Define call parameters
405
+ if prompt is not None and isinstance(prompt, str):
406
+ batch_size = 1
407
+ elif prompt is not None and isinstance(prompt, list):
408
+ batch_size = len(prompt)
409
+ else:
410
+ batch_size = prompt_embeds.shape[0]
411
+
412
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
413
+
414
+ lora_scale = (
415
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
416
+ )
417
+ has_neg_prompt = negative_prompt is not None or (
418
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
419
+ )
420
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
421
+ (
422
+ prompt_embeds,
423
+ prompt_attention_mask,
424
+ pooled_prompt_embeds,
425
+ text_ids,
426
+ ) = self.encode_prompt(
427
+ prompt=prompt,
428
+ prompt_2=prompt_2,
429
+ prompt_embeds=prompt_embeds,
430
+ pooled_prompt_embeds=pooled_prompt_embeds,
431
+ device=device,
432
+ num_images_per_prompt=num_images_per_prompt,
433
+ max_sequence_length=max_sequence_length,
434
+ lora_scale=lora_scale,
435
+ temperature=temperature,
436
+ threshold=threshold,
437
+ )
438
+ (
439
+ negative_prompt_embeds,
440
+ negative_prompt_attention_mask,
441
+ negative_pooled_prompt_embeds,
442
+ negative_text_ids,
443
+ ) = self.encode_prompt(
444
+ prompt=negative_prompt,
445
+ prompt_2=negative_prompt_2,
446
+ prompt_embeds=negative_prompt_embeds,
447
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
448
+ device=device,
449
+ num_images_per_prompt=num_images_per_prompt,
450
+ max_sequence_length=max_sequence_length,
451
+ lora_scale=lora_scale,
452
+ temperature=temperature,
453
+ threshold=threshold,
454
+ )
455
+
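+ # Note: both the positive and negative passes below end up sharing the same pooled projection;
+ # the positive pooled embedding is overwritten with the negative one (the empty prompt by default).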
456
+ pooled_prompt_embeds = negative_pooled_prompt_embeds
457
+
458
+ # 3. Preprocess image
459
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
460
+ img = image[0] if isinstance(image, list) else image
461
+ """
462
+ image_height, image_width = self.image_processor.get_default_height_width(img)
463
+ aspect_ratio = image_width / image_height
464
+ if _auto_resize:
465
+ # Kontext is trained on specific resolutions, using one of them is recommended
466
+ _, image_width, image_height = min(
467
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
468
+ )
469
+ image_width = image_width // multiple_of * multiple_of
470
+ image_height = image_height // multiple_of * multiple_of
471
+ """
472
+ image_height, image_width = original_height, original_width
473
+ image = self.image_processor.resize(image, image_height, image_width)
474
+ image = self.image_processor.preprocess(image, image_height, image_width)
475
+
476
+ # 4. Prepare latent variables
477
+ num_channels_latents = self.transformer.config.in_channels // 4
478
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
479
+ image,
480
+ batch_size * num_images_per_prompt,
481
+ num_channels_latents,
482
+ height,
483
+ width,
484
+ prompt_embeds.dtype,
485
+ device,
486
+ generator,
487
+ latents,
488
+ )
489
+ if image_ids is not None:
490
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
491
+
492
+ # 5. Prepare timesteps
493
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
494
+ image_seq_len = latents.shape[1]
495
+ mu = calculate_shift(
496
+ image_seq_len,
497
+ self.scheduler.config.get("base_image_seq_len", 256),
498
+ self.scheduler.config.get("max_image_seq_len", 4096),
499
+ self.scheduler.config.get("base_shift", 0.5),
500
+ self.scheduler.config.get("max_shift", 1.15),
501
+ )
502
+ timesteps, num_inference_steps = retrieve_timesteps(
503
+ self.scheduler,
504
+ num_inference_steps,
505
+ device,
506
+ sigmas=sigmas,
507
+ mu=mu,
508
+ )
509
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
510
+ self._num_timesteps = len(timesteps)
511
+
512
+ # handle guidance
513
+ if self.transformer.config.guidance_embeds:
514
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
515
+ guidance = guidance.expand(latents.shape[0])
516
+ else:
517
+ guidance = None
518
+
519
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
520
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
521
+ ):
522
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
523
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
524
+
525
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
526
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
527
+ ):
528
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
529
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
530
+
531
+ if self.joint_attention_kwargs is None:
532
+ self._joint_attention_kwargs = {}
533
+
534
+ image_embeds = None
535
+ negative_image_embeds = None
536
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
537
+ image_embeds = self.prepare_ip_adapter_image_embeds(
538
+ ip_adapter_image,
539
+ ip_adapter_image_embeds,
540
+ device,
541
+ batch_size * num_images_per_prompt,
542
+ )
543
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
544
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
545
+ negative_ip_adapter_image,
546
+ negative_ip_adapter_image_embeds,
547
+ device,
548
+ batch_size * num_images_per_prompt,
549
+ )
550
+
551
+ # 6. Denoising loop
552
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
553
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
554
+ self.scheduler.set_begin_index(0)
555
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
556
+ for i, t in enumerate(timesteps):
557
+ if self.interrupt:
558
+ continue
559
+
560
+ self._current_timestep = t
561
+ if image_embeds is not None:
562
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
563
+
564
+ latent_model_input = latents
565
+ if image_latents is not None:
566
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
567
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
568
+
569
+ noise_pred = self.transformer(
570
+ hidden_states=latent_model_input,
571
+ timestep=timestep / 1000,
572
+ guidance=guidance,
573
+ pooled_projections=pooled_prompt_embeds,
574
+ encoder_hidden_states=prompt_embeds,
575
+ txt_ids=text_ids,
576
+ img_ids=latent_ids,
577
+ joint_attention_kwargs={'attention_mask': prompt_attention_mask},
578
+ return_dict=False,
579
+ )[0]
580
+ noise_pred = noise_pred[:, : latents.size(1)]
581
+
582
+ if do_true_cfg:
583
+ if negative_image_embeds is not None:
584
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
585
+ neg_noise_pred = self.transformer(
586
+ hidden_states=latent_model_input,
587
+ timestep=timestep / 1000,
588
+ guidance=guidance,
589
+ pooled_projections=negative_pooled_prompt_embeds,
590
+ encoder_hidden_states=negative_prompt_embeds,
591
+ txt_ids=negative_text_ids,
592
+ img_ids=latent_ids,
593
+ joint_attention_kwargs={'attention_mask': negative_prompt_attention_mask},
594
+ return_dict=False,
595
+ )[0]
596
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
597
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
598
+
599
+ # compute the previous noisy sample x_t -> x_t-1
600
+ latents_dtype = latents.dtype
601
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
602
+
603
+ if latents.dtype != latents_dtype:
604
+ if torch.backends.mps.is_available():
605
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
606
+ latents = latents.to(latents_dtype)
607
+
608
+ if callback_on_step_end is not None:
609
+ callback_kwargs = {}
610
+ for k in callback_on_step_end_tensor_inputs:
611
+ callback_kwargs[k] = locals()[k]
612
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
613
+
614
+ latents = callback_outputs.pop("latents", latents)
615
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
616
+
617
+ # call the callback, if provided
618
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
619
+ progress_bar.update()
620
+
621
+ self._current_timestep = None
622
+
623
+ if output_type == "latent":
624
+ image = latents
625
+ else:
626
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
627
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
628
+
629
+ dtype = torch.bfloat16
630
+ image = self.vae.decode(latents.to(dtype), return_dict=False)[0]
631
+ image = self.image_processor.postprocess(image, output_type=output_type)
632
+
633
+ # Offload all models
634
+ self.maybe_free_model_hooks()
635
+
636
+ if not return_dict:
637
+ return (image,)
638
+
639
+ return FluxPipelineOutput(images=image)
640
+
641
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch>=2.0.0
2
+ gradio>=4.0.0
3
+ diffusers>=0.27.0
4
+ transformers>=4.38.0
5
+ safetensors>=0.4.0
6
+ accelerate>=0.26.0
7
+ peft>=0.8.0
8
+ Pillow>=10.0.0
requirements.txt.py ADDED
@@ -0,0 +1,8 @@
1
+ torch>=2.0.0
2
+ gradio>=4.0.0
3
+ diffusers>=0.27.0
4
+ transformers>=4.38.0
5
+ safetensors>=0.4.0
6
+ accelerate>=0.26.0
7
+ peft>=0.8.0
8
+ Pillow>=10.0.0
text_encoder.py ADDED
@@ -0,0 +1,1188 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import torch.nn as nn
5
+
6
+
7
+ def tokenize_prompt(tokenizer, prompt, max_sequence_length):
8
+ text_inputs = tokenizer(
9
+ prompt,
10
+ padding="max_length",
11
+ max_length=max_sequence_length,
12
+ truncation=True,
13
+ return_length=False,
14
+ return_overflowing_tokens=False,
15
+ return_tensors="pt",
16
+ )
17
+ text_input_ids = text_inputs.input_ids
18
+ return text_input_ids
19
+
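+ # Illustrative sketch for `tokenize_prompt` above (assumes a loaded tokenizer, e.g. a T5 tokenizer):
+ #     input_ids = tokenize_prompt(tokenizer, ["a photo of a cat"], max_sequence_length=512)
+ #     # -> LongTensor of shape [batch_size, max_sequence_length]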
20
+
21
+ def _encode_prompt_with_t5(
22
+ text_encoder,
23
+ tokenizer,
24
+ max_sequence_length=512,
25
+ prompt=None,
26
+ num_images_per_prompt=1,
27
+ device=None,
28
+ text_input_ids=None,
29
+ ):
30
+ prompt = [prompt] if isinstance(prompt, str) else prompt
31
+ batch_size = len(prompt)
32
+
33
+ if tokenizer is not None:
34
+ text_inputs = tokenizer(
35
+ prompt,
36
+ padding="max_length",
37
+ max_length=max_sequence_length,
38
+ truncation=True,
39
+ return_length=False,
40
+ return_overflowing_tokens=False,
41
+ return_tensors="pt",
42
+ )
43
+ text_input_ids = text_inputs.input_ids
44
+ else:
45
+ if text_input_ids is None:
46
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
47
+
48
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
49
+
50
+ if hasattr(text_encoder, "module"):
51
+ dtype = text_encoder.module.dtype
52
+ else:
53
+ dtype = text_encoder.dtype
54
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
55
+
56
+ _, seq_len, _ = prompt_embeds.shape
57
+
58
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
59
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
60
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
61
+
62
+ return prompt_embeds
63
+
64
+
65
+ def _encode_prompt_with_clip(
66
+ text_encoder,
67
+ tokenizer,
68
+ prompt: str,
69
+ device=None,
70
+ text_input_ids=None,
71
+ num_images_per_prompt: int = 1,
72
+ ):
73
+ prompt = [prompt] if isinstance(prompt, str) else prompt
74
+ batch_size = len(prompt)
75
+
76
+ if tokenizer is not None:
77
+ text_inputs = tokenizer(
78
+ prompt,
79
+ padding="max_length",
80
+ max_length=77,
81
+ truncation=True,
82
+ return_overflowing_tokens=False,
83
+ return_length=False,
84
+ return_tensors="pt",
85
+ )
86
+
87
+ text_input_ids = text_inputs.input_ids
88
+ else:
89
+ if text_input_ids is None:
90
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
91
+
92
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
93
+
94
+ if hasattr(text_encoder, "module"):
95
+ dtype = text_encoder.module.dtype
96
+ else:
97
+ dtype = text_encoder.dtype
98
+ # Use pooled output of CLIPTextModel
99
+ prompt_embeds = prompt_embeds.pooler_output
100
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
101
+
102
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
103
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
104
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
105
+
106
+ return prompt_embeds
107
+
108
+
109
+ def encode_prompt(
110
+ text_encoders,
111
+ tokenizers,
112
+ prompt: str,
113
+ max_sequence_length,
114
+ device=None,
115
+ num_images_per_prompt: int = 1,
116
+ text_input_ids_list=None,
117
+ ):
118
+ prompt = [prompt] if isinstance(prompt, str) else prompt
119
+
120
+ if hasattr(text_encoders[0], "module"):
121
+ dtype = text_encoders[0].module.dtype
122
+ else:
123
+ dtype = text_encoders[0].dtype
124
+
125
+ pooled_prompt_embeds = _encode_prompt_with_clip(
126
+ text_encoder=text_encoders[0],
127
+ tokenizer=tokenizers[0],
128
+ prompt=prompt,
129
+ device=device if device is not None else text_encoders[0].device,
130
+ num_images_per_prompt=num_images_per_prompt,
131
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
132
+ )
133
+
134
+ prompt_embeds = _encode_prompt_with_t5(
135
+ text_encoder=text_encoders[1],
136
+ tokenizer=tokenizers[1],
137
+ max_sequence_length=max_sequence_length,
138
+ prompt=prompt,
139
+ num_images_per_prompt=num_images_per_prompt,
140
+ device=device if device is not None else text_encoders[1].device,
141
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
142
+ )
143
+
144
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
145
+
146
+ return prompt_embeds, pooled_prompt_embeds, text_ids
147
+
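+ # Illustrative sketch for `encode_prompt` above (assumes CLIP and T5 encoders/tokenizers are already
+ # loaded; the variable names are placeholders):
+ #     prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ #         text_encoders=[clip_text_encoder, t5_text_encoder],
+ #         tokenizers=[clip_tokenizer, t5_tokenizer],
+ #         prompt="a photo of a cat",
+ #         max_sequence_length=512,
+ #     )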
148
+
149
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel
150
+ class T5Embedder(torch.nn.Module):
151
+ def __init__(self, device, max_length=300):
152
+ super().__init__()
153
+ self.device = device
154
+ self.max_length = max_length
155
+ dtype = torch.bfloat16
156
+ self.dtype = dtype
157
+ t5_version = './t5-v1_1-xxl'
158
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
159
+ self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device)
160
+ self.t5_encoder = self.t5_encoder.eval().requires_grad_(False)
161
+ self.num_shared = max_length
162
+
163
+ @torch.no_grad()
164
+ def forward(self, text):
165
+ if isinstance(text, str):
166
+ text = [text]
167
+ batch_encoding = self.t5_tokenizer(
168
+ text,
169
+ truncation=True,
170
+ max_length=self.max_length,
171
+ return_length=False,
172
+ return_overflowing_tokens=False,
173
+ padding="max_length",
174
+ return_tensors="pt",
175
+ )
176
+
177
+ prompt_embeds = self.t5_encoder(
178
+ input_ids=batch_encoding["input_ids"].to(self.device),
179
+ attention_mask=None,
180
+ output_hidden_states=False,
181
+ )['last_hidden_state']
182
+ prompt_attention_mask = batch_encoding['attention_mask'].to(self.device)
183
+
184
+
185
+ new_text = [x.split('.')[0] for x in text]
186
+ batch_encoding = self.t5_tokenizer(
187
+ new_text,
188
+ truncation=True,
189
+ max_length=self.num_shared,
190
+ return_length=False,
191
+ return_overflowing_tokens=False,
192
+ padding="max_length",
193
+ return_tensors="pt",
194
+ )
195
+ shared_prompt_embeds = self.t5_encoder(
196
+ input_ids=batch_encoding["input_ids"].to(self.device),
197
+ attention_mask=None,
198
+ output_hidden_states=False,
199
+ )['last_hidden_state']
200
+
201
+ return prompt_embeds, shared_prompt_embeds, prompt_attention_mask
202
+
203
+
204
+
205
+
206
+ import random
207
+
208
+ from torch.utils.checkpoint import checkpoint
209
+ from peft import LoraConfig, set_peft_model_state_dict
210
+ class LoraT5EmbedderNoGradientCheck(torch.nn.Module):
211
+ def __init__(self, device, rank=64, max_length=300):
212
+ super().__init__()
213
+ self.device = device
214
+ self.max_length = max_length
215
+ dtype = torch.bfloat16
216
+ self.dtype = dtype
217
+ t5_version = './t5-v1_1-xxl'
218
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
219
+ self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device).to(dtype)
220
+ self.t5_encoder.gradient_checkpointing_enable()
221
+ self.t5_encoder.config.gradient_checkpointing = True
222
+ self.t5_encoder.requires_grad_(False)
223
+ self.t5_encoder.eval()
224
+ # Add LoRA adapters to the T5 model
225
+ text_lora_config = LoraConfig(
226
+ r=rank,
227
+ lora_alpha=rank,
228
+ lora_dropout=0.0,
229
+ init_lora_weights="gaussian",
230
+ target_modules=["SelfAttention.q", "SelfAttention.k", "SelfAttention.v", "SelfAttention.o", "DenseReluDense.wi", "DenseReluDense.wo"],
231
+ )
232
+ self.t5_encoder.add_adapter(text_lora_config)
233
+ #self.t5_encoder.encoder.embed_tokens.weight.requires_grad = True
234
+ print(f"Gradient checkpointing enabled: {self.t5_encoder.is_gradient_checkpointing}")
235
+
236
+ image_encoder_path = 'openai/clip-vit-large-patch14'
237
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(device=device).to(torch.bfloat16)
238
+ self.image_encoder = self.image_encoder.eval().requires_grad_(False)
239
+
240
+ def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding):
241
+ """
242
+ Compute group lasso for non-pad non-change tokens, L1 for change tokens,
243
+ and group sparsity for pad non-change tokens.
244
+
245
+ Args:
246
+ prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim]
247
+ perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim]
248
+ replaced_ids: List of replaced token indices for each sample in batch
249
+ batch_encoding: The tokenizer output containing input_ids
250
+
251
+ Returns:
252
+ l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor)
253
+ l1_loss: L1 loss for change tokens (scalar tensor)
254
+ pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor)
255
+ """
256
+ batch_size = prompt_embeds.size(0)
257
+ pad_token_id = self.t5_tokenizer.pad_token_id
258
+ input_ids = batch_encoding["input_ids"]
259
+
260
+ l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
261
+ l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
262
+ pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
263
+
264
+ # Track valid samples for each loss type separately
265
+ l1_valid_samples = 0
266
+ l2_valid_samples = 0
267
+ pad_valid_samples = 0
268
+
269
+ for i in range(batch_size):
270
+ # Get the replaced index for this sample
271
+ replaced_idx = replaced_ids[i]
272
+
273
+ if replaced_idx is None:
274
+ # No replacement happened (all padding), skip
275
+ continue
276
+
277
+ # Find padding and non-padding token indices
278
+ pad_mask = input_ids[i] == pad_token_id
279
+ non_pad_mask = ~pad_mask
280
+
281
+ pad_indices = torch.where(pad_mask)[0]
282
+ non_pad_indices = torch.where(non_pad_mask)[0]
283
+
284
+ # Filter out the replaced index from non-padding indices (non-pad non-change)
285
+ non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx]
286
+
287
+ # Compute L1 loss on selected (replaced) index - CHANGE TOKEN
288
+ selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx]
289
+ l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean()
290
+ l1_valid_samples += 1
291
+
292
+ # Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens
293
+ if len(non_selected_non_pad_indices) > 0:
294
+ non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[
295
+ i, non_selected_non_pad_indices]
296
+ l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1))
297
+ l2_loss_total = l2_loss_total + l2_per_token.mean()
298
+ l2_valid_samples += 1
299
+
300
+ # Compute group sparsity loss on PAD NON-CHANGE tokens
301
+ if len(pad_indices) > 0:
302
+ pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices]
303
+ # Group sparsity: L2 norm per token (encourages entire token embeddings to be zero)
304
+ pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1))
305
+ pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean()
306
+ pad_valid_samples += 1
307
+
308
+ # Average over valid samples for each loss type
309
+ l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else torch.tensor(0.0,
310
+ device=prompt_embeds.device)
311
+ l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else torch.tensor(0.0,
312
+ device=prompt_embeds.device)
313
+ pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else torch.tensor(0.0,
314
+ device=prompt_embeds.device)
315
+
316
+ return l2_loss, l1_loss, pad_group_loss
317
+
318
+
319
+
320
+
321
+ def forward(self, text, image=None):
322
+ if isinstance(text, str):
323
+ text = [text]
324
+ batch_encoding = self.t5_tokenizer(
325
+ text,
326
+ truncation=True,
327
+ max_length=self.max_length,
328
+ return_length=False,
329
+ return_overflowing_tokens=False,
330
+ padding="max_length",
331
+ return_tensors="pt",
332
+ )
333
+ prompt_embeds = self.t5_encoder(
334
+ input_ids=batch_encoding["input_ids"].to(self.device),
335
+ attention_mask=None,
336
+ output_hidden_states=False,
337
+ )['last_hidden_state']
338
+
339
+ # Get input_ids and create a copy to modify
340
+ input_ids = batch_encoding["input_ids"].clone()
341
+ batch_size = input_ids.size(0)
342
+
343
+ # Get the padding token id
344
+ pad_token_id = self.t5_tokenizer.pad_token_id
345
+
346
+ replaced_ids = []
347
+ # For each sample in the batch
348
+ for i in range(batch_size):
349
+ # Find indices of non-padding tokens
350
+ non_pad_mask = input_ids[i] != pad_token_id
351
+ non_pad_indices = torch.where(non_pad_mask)[0]
352
+
353
+ # If there are meaningful tokens, randomly select one to replace
354
+ if len(non_pad_indices) > 0:
355
+ # Randomly select an index from non-padding tokens
356
+ random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
357
+ # Replace with padding token
358
+ input_ids[i, random_idx] = pad_token_id
359
+ replaced_ids.append(random_idx.item())
360
+ else:
361
+ replaced_ids.append(None) # No replacement if all tokens are padding
362
+
363
+
364
+ perturbed_prompt_embeds = self.t5_encoder(
365
+ input_ids=input_ids.to(self.device),
366
+ attention_mask=None,
367
+ output_hidden_states=False,
368
+ )['last_hidden_state']
369
+
370
+ l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss(
371
+ prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
372
+ )
373
+
374
+ with torch.no_grad():
375
+ if image is not None:
376
+ clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
377
+ else:
378
+ clip_image_embeds = None
379
+
380
+
381
+ return prompt_embeds, l2_loss, l1_loss, pad_loss,clip_image_embeds
382
+
383
+
384
+ from peft import LoraConfig, set_peft_model_state_dict
385
+ import torch.utils.checkpoint as checkpoint
386
+ from transformers import CLIPVisionModelWithProjection
387
+
388
+ class LoraT5Embedder(torch.nn.Module):
389
+ def __init__(self, device, rank=128, max_length=300, use_gradient_checkpointing=True):
390
+ super().__init__()
391
+ self.device = device
392
+ self.max_length = max_length
393
+ self.use_gradient_checkpointing = use_gradient_checkpointing
394
+ dtype = torch.bfloat16
395
+ self.dtype = dtype
396
+ t5_version = './t5-v1_1-xxl'
397
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
398
+
399
+ self.t5_encoder = T5EncoderModel.from_pretrained(
400
+ t5_version,
401
+ torch_dtype=dtype
402
+ ).to(device=device).to(dtype)
403
+
404
+ self.t5_encoder.requires_grad_(False)
405
+
406
+ # Add LoRA adapters to the T5 model
407
+ text_lora_config = LoraConfig(
408
+ r=rank,
409
+ lora_alpha=rank,
410
+ lora_dropout=0.0,
411
+ init_lora_weights="gaussian",
412
+ target_modules=["q", "k", "v", "o", "wi", "wo"],
413
+ )
414
+ self.t5_encoder.add_adapter(text_lora_config)
415
+ self.t5_encoder.encoder.embed_tokens.weight.requires_grad_(True)
416
+
417
+ # Manually implement gradient checkpointing for T5 encoder blocks
418
+ if self.use_gradient_checkpointing:
419
+ self._enable_gradient_checkpointing()
420
+
421
+ print(f"Gradient checkpointing enabled: {self.use_gradient_checkpointing}")
422
+
423
+ image_encoder_path = './clip-vit-large-patch14'
424
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
425
+ image_encoder_path
426
+ ).to(device=device).to(torch.bfloat16)
427
+ self.image_encoder = self.image_encoder.eval().requires_grad_(False)
428
+
429
+ def _enable_gradient_checkpointing(self):
430
+ """
431
+ Manually wrap T5 encoder blocks with gradient checkpointing.
432
+ """
433
+
434
+ def create_custom_forward(module):
435
+ def custom_forward(*inputs):
436
+ return module(*inputs)
437
+
438
+ return custom_forward
439
+
440
+ # Wrap each T5 block with checkpointing
441
+ for block in self.t5_encoder.encoder.block:
442
+ # Store original forward
443
+ block._original_forward = block.forward
444
+
445
+ # Create checkpointed forward
446
+ def make_checkpointed_forward(blk):
447
+ def checkpointed_forward(*args, **kwargs):
448
+ # Checkpoint requires a function that takes tensors as input
449
+ def forward_wrapper(*inputs):
450
+ # Reconstruct kwargs from inputs
451
+ hidden_states = inputs[0]
452
+ attention_mask = inputs[1] if len(inputs) > 1 else None
453
+ position_bias = inputs[2] if len(inputs) > 2 else None
454
+
455
+ return blk._original_forward(
456
+ hidden_states=hidden_states,
457
+ attention_mask=attention_mask,
458
+ position_bias=position_bias,
459
+ **{k: v for k, v in kwargs.items() if
460
+ k not in ['hidden_states', 'attention_mask', 'position_bias']}
461
+ )
462
+
463
+ # Prepare inputs for checkpointing
464
+ hidden_states = kwargs.get('hidden_states', args[0] if args else None)
465
+ attention_mask = kwargs.get('attention_mask', args[1] if len(args) > 1 else None)
466
+ position_bias = kwargs.get('position_bias', args[2] if len(args) > 2 else None)
467
+
468
+ # Use checkpoint
469
+ checkpoint_inputs = [hidden_states]
470
+ if attention_mask is not None:
471
+ checkpoint_inputs.append(attention_mask)
472
+ if position_bias is not None:
473
+ checkpoint_inputs.append(position_bias)
474
+
475
+ return checkpoint.checkpoint(
476
+ forward_wrapper,
477
+ *checkpoint_inputs,
478
+ use_reentrant=False
479
+ )
480
+
481
+ return checkpointed_forward
482
+
483
+ block.forward = make_checkpointed_forward(block)
484
+
485
+ def _encode_text(self, input_ids):
486
+ """Helper function to encode text through T5."""
487
+ return self.t5_encoder(
488
+ input_ids=input_ids.to(self.device),
489
+ attention_mask=None,
490
+ output_hidden_states=False,
491
+ )['last_hidden_state']
492
+
493
+ def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding):
494
+ """
495
+ Compute group lasso for non-pad non-change tokens, L1 for change tokens,
496
+ and group sparsity for pad non-change tokens.
497
+
498
+ Args:
499
+ prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim]
500
+ perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim]
501
+ replaced_ids: List of replaced token indices for each sample in batch
502
+ batch_encoding: The tokenizer output containing input_ids
503
+
504
+ Returns:
505
+ l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor)
506
+ l1_loss: L1 loss for change tokens (scalar tensor)
507
+ pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor)
508
+ """
509
+ batch_size = prompt_embeds.size(0)
510
+ pad_token_id = self.t5_tokenizer.pad_token_id
511
+ input_ids = batch_encoding["input_ids"]
512
+
513
+ l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
514
+ l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
515
+ pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
516
+
517
+ # Track valid samples for each loss type separately
518
+ l1_valid_samples = 0
519
+ l2_valid_samples = 0
520
+ pad_valid_samples = 0
521
+
522
+ for i in range(batch_size):
523
+ # Get the replaced index for this sample
524
+ replaced_idx = replaced_ids[i]
525
+
526
+ if replaced_idx is None:
527
+ # No replacement happened (all padding), skip
528
+ continue
529
+
530
+ # Find padding and non-padding token indices
531
+ pad_mask = input_ids[i] == pad_token_id
532
+ non_pad_mask = ~pad_mask
533
+
534
+ pad_indices = torch.where(pad_mask)[0]
535
+ non_pad_indices = torch.where(non_pad_mask)[0]
536
+
537
+ # Filter out the replaced index from non-padding indices (non-pad non-change)
538
+ non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx]
539
+
540
+ # Compute L1 loss on selected (replaced) index - CHANGE TOKEN
541
+ selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx]
542
+ l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean()
543
+ l1_valid_samples += 1
544
+
545
+ # Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens
546
+ if len(non_selected_non_pad_indices) > 0:
547
+ non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[
548
+ i, non_selected_non_pad_indices]
549
+ l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1))
550
+ l2_loss_total = l2_loss_total + l2_per_token.mean()
551
+ l2_valid_samples += 1
552
+
553
+ # Compute group sparsity loss on PAD NON-CHANGE tokens
554
+ if len(pad_indices) > 0:
555
+ pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices]
556
+ # Group sparsity: L2 norm per token (encourages entire token embeddings to be zero)
557
+ pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1))
558
+ pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean()
559
+ pad_valid_samples += 1
560
+
561
+ # Average over valid samples for each loss type
562
+ l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else torch.tensor(0.0,
563
+ device=prompt_embeds.device)
564
+ l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else torch.tensor(0.0,
565
+ device=prompt_embeds.device)
566
+ pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else torch.tensor(0.0,
567
+ device=prompt_embeds.device)
568
+
569
+ return l2_loss, l1_loss, pad_group_loss
570
+
571
+ def forward(self, text, image=None):
572
+ if isinstance(text, str):
573
+ text = [text]
574
+ batch_encoding = self.t5_tokenizer(
575
+ text,
576
+ truncation=True,
577
+ max_length=self.max_length,
578
+ return_length=False,
579
+ return_overflowing_tokens=False,
580
+ padding="max_length",
581
+ return_tensors="pt",
582
+ )
583
+ attn_mask = batch_encoding["attention_mask"].to(self.device)
584
+
585
+ # First encoding
586
+ prompt_embeds = self._encode_text(batch_encoding["input_ids"])
587
+
588
+ # Get input_ids and create a copy to modify
589
+ input_ids = batch_encoding["input_ids"].clone()
590
+ batch_size = input_ids.size(0)
591
+
592
+ # Get the padding token id
593
+ # get the id for the first sentinel token
594
+ mask_token = "<extra_id_0>"
595
+ mask_token_id = self.t5_tokenizer.convert_tokens_to_ids(mask_token)
596
+ pad_token_id = self.t5_tokenizer.pad_token_id
597
+
598
+ replaced_ids = []
599
+ # For each sample in the batch
600
+ for i in range(batch_size):
601
+ # Find indices of non-padding tokens
602
+ non_pad_mask = input_ids[i] != pad_token_id
603
+ non_pad_indices = torch.where(non_pad_mask)[0]
604
+
605
+ # If there are meaningful tokens, randomly select one to replace
606
+ if len(non_pad_indices) > 0:
607
+ # Randomly select an index from non-padding tokens
608
+ random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
609
+ random_idx2 = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
610
+ # Replace with padding token
611
+ input_ids[i, random_idx] = mask_token_id
612
+ replaced_ids.append(random_idx.item())
613
+ else:
614
+ replaced_ids.append(None) # No replacement if all tokens are padding
615
+
616
+ # Second encoding with perturbed input
617
+ perturbed_prompt_embeds = self._encode_text(input_ids)
618
+
619
+ """
620
+ l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss(
621
+ prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
622
+ )
623
+ """
624
+
625
+ with torch.no_grad():
626
+ if image is not None:
627
+ clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
628
+ else:
629
+ clip_image_embeds = None
630
+
631
+ #return prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds, attn_mask
632
+ return prompt_embeds, clip_image_embeds, perturbed_prompt_embeds, replaced_ids, self.t5_tokenizer, batch_encoding
633
+
634
+
635
+
636
+
637
+
638
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
639
+ class QwenEmbedder(nn.Module):
640
+ def __init__(self, device, max_length=512):
641
+ super().__init__()
642
+ self.device = device
643
+ self.max_length = max_length
644
+ dtype = torch.bfloat16
645
+ self.dtype = dtype
646
+ self.tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
647
+ self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
648
+ "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=dtype,
649
+ ).to(device=device)
650
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
651
+ self.prompt_template_encode_start_idx = 34
652
+ self.tokenizer_max_length = max_length
653
+
654
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
655
+ bool_mask = mask.bool()
656
+ valid_lengths = bool_mask.sum(dim=1)
657
+ selected = hidden_states[bool_mask]
658
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
659
+
660
+ return split_result
661
+
662
+ def _get_qwen_prompt_embeds(
663
+ self,
664
+ prompt = None,
665
+ device = None,
666
+ dtype = None,
667
+ ):
668
+ device = device or self._execution_device
669
+ dtype = dtype or self.text_encoder.dtype
670
+
671
+ prompt = [prompt] if isinstance(prompt, str) else prompt
672
+
673
+ template = self.prompt_template_encode
674
+ drop_idx = self.prompt_template_encode_start_idx
675
+ txt = [template.format(e) for e in prompt]
676
+ txt_tokens = self.tokenizer(
677
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
678
+ ).to(device)
679
+ encoder_hidden_states = self.text_encoder(
680
+ input_ids=txt_tokens.input_ids,
681
+ attention_mask=txt_tokens.attention_mask,
682
+ output_hidden_states=True,
683
+ )
684
+ hidden_states = encoder_hidden_states.hidden_states[-1]
685
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
686
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
687
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
688
+ #max_seq_len = max([e.size(0) for e in split_hidden_states])
689
+ max_seq_len = self.max_length
690
+ prompt_embeds = torch.stack(
691
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
692
+ )
693
+ encoder_attention_mask = torch.stack(
694
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
695
+ )
696
+
697
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
698
+
699
+ return prompt_embeds, encoder_attention_mask
700
+
701
+ @torch.no_grad()
702
+ def forward(self, text):
703
+ prompt_embeds, attention_mask = self._get_qwen_prompt_embeds(
704
+ prompt=text,
705
+ device=self.device,
706
+ dtype=self.dtype,
707
+ )
708
+ return prompt_embeds, attention_mask
709
+
710
+
711
+
712
+ # pip install accelerate
713
+
714
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
715
+ from PIL import Image
716
+ import requests
717
+ import torch
718
+ import torch.nn as nn
719
+ Qwen25VL_7b_PREFIX_edit = '''Given a user editing prompt and a source image, only describe the editing area and how it should change in a detailed way.
720
+ Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
721
+ '''
722
+
723
+ Qwen25VL_7b_PREFIX_t2i = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
724
+ - If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
725
+ - If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
726
+ Here are examples of how to transform or refine prompts:
727
+ - User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
728
+ - User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
729
+ Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
730
+ User Prompt:'''
731
+ Qwen25VL_7b_PREFIX_image = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
732
+ model_id = "google/gemma-3-4b-it"
733
+
734
+ from transformers import AutoTokenizer, TrainingArguments, Gemma3ForCausalLM, AutoModel, Gemma3Model
735
+ from transformers import Dinov2Model, AutoImageProcessor
736
+
737
+ import torch
738
+ import torchvision.transforms as transforms
739
+ import torchvision.models as models
740
+ from PIL import Image
741
+ import numpy as np
742
+
743
+
744
+
745
+ class GemmaEmbedder(nn.Module):
+     def __init__(self, max_sequence_length=300, model_id='google/gemma-3-4b-it'):
+         super().__init__()
+         device = torch.cuda.current_device()
+         self.model = Gemma3Model.from_pretrained(model_id).to(device).to(torch.bfloat16)
+         #self.model = Gemma3ForConditionalGeneration.from_pretrained(model_id).to(device).to(torch.bfloat16)
+         self.processor = AutoProcessor.from_pretrained(model_id)
+         self.device = device
+         self.max_sequence_length = max_sequence_length
+         #self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token  # Use eos token as pad token
+         self.processor.tokenizer.padding_side = "right"
+
+     def get_features(self, hidden_states, input_ids):
+         # Split one sample's last-layer hidden states into image-token embeddings and
+         # right-padded text embeddings with a matching attention mask.
+         hidden_states = hidden_states[0]
+         input_ids = input_ids[0].tolist()
+
+         pad_text_embeds = torch.zeros([self.max_sequence_length, 2560], dtype=torch.bfloat16, device=self.device)
+         pad_text_mask = torch.zeros([self.max_sequence_length], device=self.device)
+
+         def find_last(lst, value):
+             indices = [i for i, x in enumerate(lst) if x == value]
+             return indices[-1]
+
+         if 256000 in input_ids:
+             text_start = input_ids.index(256000) + 2
+         else:
+             text_start = find_last(input_ids, 108) + 1
+
+         text_end = len(input_ids) - 6
+         bos_embed = hidden_states[:2]
+         text_embeds = hidden_states[text_start:text_end + 1]
+         text_embeds = torch.cat([bos_embed, text_embeds], dim=0)
+
+         pad_text_embeds[:len(text_embeds), :] = text_embeds[:self.max_sequence_length]
+         pad_text_mask[:len(text_embeds)] = 1.0
+
+         image_embeds = hidden_states[np.array(input_ids) == self.processor.tokenizer.image_token_id]
+
+         """
+         print(input_ids)
+         print(input_ids[text_start:text_end + 1])
+         decoded = self.processor.decode(input_ids[text_start:text_end+1], skip_special_tokens=False)
+         print("Decoded text:", decoded, text_start, text_end, input_ids[text_start:text_end + 1], input_ids[1:2])
+         print("Text embeddings shape:", text_embeds.shape)
+         norm = RMSNorm(2560, eps=1e-6).to(self.device).to(torch.bfloat16)
+         print(text_embeds, ' >>> ext embeds')
+         print(norm(text_embeds), ' >>> normed embeds')
+         """
+
+         return image_embeds, pad_text_embeds, pad_text_mask
+
+     @torch.no_grad()
+     def forward(self, caps, images=None):
+         text_embeds = []
+         text_masks = []
+         full_image_embeds = []
+         device = self.model.device
+         if images is None:
+             images = [None] * len(caps)
+         for cap, img in zip(caps, images):
+             # Image-editing system prefix when a reference image is given, text-to-image prefix otherwise.
+             if img is not None:
+                 messages = [
+                     {
+                         "role": "system",
+                         "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_edit}]
+                     },
+                     {
+                         "role": "user",
+                         "content": [
+                             {"type": "image", "image": img},
+                             {"type": "text", "text": cap},
+                         ]
+                     }
+                 ]
+             else:
+                 messages = [
+                     {
+                         "role": "system",
+                         "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_t2i}]
+                     },
+                     {
+                         "role": "user",
+                         "content": [
+                             {"type": "text", "text": cap},
+                         ]
+                     }
+                 ]
+
+             inputs = self.processor.apply_chat_template(
+                 messages, add_generation_prompt=True, tokenize=True,
+                 return_dict=True, return_tensors="pt",
+                 max_length=640,
+                 truncation=True,
+             ).to(self.model.device, dtype=torch.bfloat16)
+             outputs = self.model(**inputs, output_hidden_states=True)
+             #sample_image_embeds = outputs.image_hidden_states
+             sample_text_embeds, sample_text_mask, sample_image_embeds = [], [], []
+             for hidden in [outputs.hidden_states[-1]]:
+                 cur_image_embeds, cur_text_embeds, cur_text_mask = self.get_features(hidden, inputs["input_ids"])
+                 sample_text_embeds.append(cur_text_embeds)
+                 sample_text_mask.append(cur_text_mask)
+                 sample_image_embeds.append(cur_image_embeds)
+             text_embeds.append(torch.cat(sample_text_embeds, dim=0))
+             text_masks.append(torch.cat(sample_text_mask, dim=0))
+             #full_image_embeds.append(sample_image_embeds)
+             full_image_embeds.append(torch.cat(sample_image_embeds, dim=0))
+
+             """
+             input_len = inputs["input_ids"].shape[-1]
+             with torch.inference_mode():
+                 generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
+                 generation = generation[0][input_len:]
+
+             decoded = self.processor.decode(generation, skip_special_tokens=True)
+             print(cap, ' <>>> gemma ', decoded)
+             """
+
+         text_embeds = torch.stack(text_embeds, dim=0)
+         text_masks = torch.stack(text_masks, dim=0)
+         full_image_embeds = torch.stack(full_image_embeds, dim=0)
+         return {
+             'text_embeds': text_embeds,
+             'text_masks': text_masks,
+             'image_embeds': full_image_embeds,
+         }
+
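# Illustrative usage sketch (assumptions: Gemma3Model, AutoProcessor and the Qwen25VL_7b_PREFIX_*
# system prompts are defined earlier in aligner.py, a CUDA device is visible, and the
# 'google/gemma-3-4b-it' weights can be loaded; "reference.png" is a hypothetical input image):
from PIL import Image

embedder = GemmaEmbedder(max_sequence_length=300)
ref_image = Image.open("reference.png").convert("RGB")
out = embedder(["make the sky purple"], images=[ref_image])
print(out['text_embeds'].shape)   # [1, 300, 2560] right-padded prompt embeddings
print(out['text_masks'].shape)    # [1, 300], 1.0 where a prompt token is present
print(out['image_embeds'].shape)  # hidden states gathered at the image-token positions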
+ class GemmaTextEmbedder(nn.Module):
+     def __init__(self, device, max_sequence_length=300, model_id='./gemma-3-4b-it'):
+         super().__init__()
+         self.model = Gemma3Model.from_pretrained(model_id).to(device).to(torch.bfloat16)
+         #self.model = Gemma3ForConditionalGeneration.from_pretrained(model_id).to(device).to(torch.bfloat16)
+         self.processor = AutoProcessor.from_pretrained(model_id)
+         self.real_device = device
+         self.max_sequence_length = max_sequence_length
+         #self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token  # Use eos token as pad token
+         self.processor.tokenizer.padding_side = "right"
+
+     @property
+     def dtype(self):
+         """Return the dtype of the model parameters."""
+         return next(self.parameters()).dtype
+
+     @property
+     def device(self):
+         """Return the device of the model parameters."""
+         return next(self.parameters()).device
+
+     def get_features(self, hidden_states, input_ids):
+         # Text-only variant of GemmaEmbedder.get_features: returns only the padded text
+         # embeddings and the matching attention mask.
+         hidden_states = hidden_states[0]
+         input_ids = input_ids[0].tolist()
+
+         pad_text_embeds = torch.zeros([self.max_sequence_length, 2560], dtype=torch.bfloat16, device=self.device)
+         pad_text_mask = torch.zeros([self.max_sequence_length], device=self.device)
+
+         def find_last(lst, value):
+             indices = [i for i, x in enumerate(lst) if x == value]
+             return indices[-1]
+
+         if 256000 in input_ids:
+             text_start = input_ids.index(256000) + 2
+         else:
+             text_start = find_last(input_ids, 108) + 1
+
+         text_end = len(input_ids) - 6
+         bos_embed = hidden_states[:2]
+         text_embeds = hidden_states[text_start:text_end + 1]
+         text_embeds = torch.cat([bos_embed, text_embeds], dim=0)
+
+         pad_text_embeds[:len(text_embeds), :] = text_embeds[:self.max_sequence_length]
+         pad_text_mask[:len(text_embeds)] = 1.0
+
+         pad_text_embeds[len(text_embeds):, :] = 0.0
+         pad_text_mask[len(text_embeds):] = 0.0
+
+         """
+         print(input_ids)
+         print(input_ids[text_start:text_end + 1])
+         decoded = self.processor.decode(input_ids[text_start:text_end+1], skip_special_tokens=False)
+         print("Decoded text:", decoded, text_start, text_end, input_ids[text_start:text_end + 1], input_ids[1:2])
+         print("Text embeddings shape:", text_embeds.shape)
+         print(text_embeds, ' >>> ext embeds')
+         """
+
+         return pad_text_embeds, pad_text_mask
+
+     @torch.no_grad()
+     def forward(self, caps, images=None):
+         text_embeds = []
+         text_masks = []
+         full_image_embeds = []
+         device = self.model.device
+         if isinstance(caps, str):
+             caps = [caps]
+         if images is None:
+             images = [None] * len(caps)
+         for cap, img in zip(caps, images):
+             if img is not None:
+                 messages = [
+                     {
+                         "role": "system",
+                         "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_edit}]
+                     },
+                     {
+                         "role": "user",
+                         "content": [
+                             {"type": "image", "image": img},
+                             {"type": "text", "text": cap},
+                         ]
+                     }
+                 ]
+             else:
+                 messages = [
+                     {
+                         "role": "system",
+                         "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_t2i}]
+                     },
+                     {
+                         "role": "user",
+                         "content": [
+                             {"type": "text", "text": cap},
+                         ]
+                     }
+                 ]
+
+             inputs = self.processor.apply_chat_template(
+                 messages, add_generation_prompt=True, tokenize=True,
+                 return_dict=True, return_tensors="pt",
+                 max_length=640,
+                 truncation=True,
+             ).to(self.model.device, dtype=torch.bfloat16)
+             outputs = self.model(**inputs, output_hidden_states=True)
+             #sample_image_embeds = outputs.image_hidden_states
+             sample_text_embeds, sample_text_mask, sample_image_embeds = [], [], []
+             for hidden in [outputs.hidden_states[-1]]:
+                 cur_text_embeds, cur_text_mask = self.get_features(hidden, inputs["input_ids"])
+                 sample_text_embeds.append(cur_text_embeds)
+                 sample_text_mask.append(cur_text_mask)
+             text_embeds.append(torch.cat(sample_text_embeds, dim=0))
+             text_masks.append(torch.cat(sample_text_mask, dim=0))
+
+             """
+             input_len = inputs["input_ids"].shape[-1]
+             with torch.inference_mode():
+                 generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
+                 generation = generation[0][input_len:]
+
+             decoded = self.processor.decode(generation, skip_special_tokens=True)
+             print(cap, ' <>>> gemma ', decoded)
+             """
+
+         text_embeds = torch.stack(text_embeds, dim=0)
+         text_masks = torch.stack(text_masks, dim=0)
+         return text_embeds, text_masks.to(text_embeds.dtype)
+
+
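# Illustrative usage sketch for the text-only path (assumes the same upstream imports and a
# local './gemma-3-4b-it' checkpoint, which is the default model_id above):
text_embedder = GemmaTextEmbedder(device="cuda:0")
embeds, masks = text_embedder("a watercolor painting of a lighthouse at dusk")
print(embeds.shape, masks.shape)  # torch.Size([1, 300, 2560]) torch.Size([1, 300])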
+ from transformers import AutoModel, AutoTokenizer
+ from transformers import SiglipVisionModel, AutoProcessor
+
+
+ class Gemma2Embedder(nn.Module):
+     def __init__(self, max_length=300):
+         super().__init__()
+         self.text_encoder = AutoModel.from_pretrained(
+             "google/gemma-2-2b",
+             torch_dtype=torch.bfloat16,
+         ).to(torch.cuda.current_device()).to(torch.bfloat16).eval()
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             "google/gemma-2-2b",
+         )
+         self.tokenizer.padding_side = "right"
+         self.max_length = max_length
+         self.system_prompt = "You are an assistant designed to edit images faithfully based on user prompts. <Prompt Start> "
+         system_ids = self.tokenizer(
+             self.system_prompt,
+             return_tensors="pt",
+             add_special_tokens=True,
+             max_length=self.max_length,
+             padding="max_length",
+             truncation=True,
+         ).input_ids.flatten().view(-1).numpy().tolist()
+         # Number of leading positions occupied by the system prompt (index of the first pad
+         # token minus one); these positions are stripped again in forward().
+         self.len_system_prompt = system_ids.index(self.tokenizer.pad_token_id) - 1
+         self.weight_dtype = torch.bfloat16
+
+     @torch.no_grad()
+     def forward(self, caption):
+         if isinstance(caption, str):
+             caption = [caption]
+         caption = [self.system_prompt + c for c in caption]
+         text_inputs = self.tokenizer(
+             caption,
+             return_tensors="pt",
+             add_special_tokens=True,
+             max_length=self.max_length + self.len_system_prompt,
+             padding="max_length",
+             truncation=True,
+         )
+         text_input_ids = text_inputs.input_ids
+         attention_mask = text_inputs.attention_mask
+         text_input_ids = text_input_ids.to(self.text_encoder.device)
+         attention_mask = attention_mask.to(self.text_encoder.device)
+         embeds = self.text_encoder(
+             text_input_ids, attention_mask=attention_mask, output_hidden_states=True
+         ).hidden_states[-2]
+         embeds = embeds[:, self.len_system_prompt:, :]
+         attention_mask = attention_mask[:, self.len_system_prompt:]
+
+         return {
+             'text_embeds': embeds,
+             'text_masks': attention_mask,
+         }
+
+
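# Illustrative usage sketch (assumes the 'google/gemma-2-2b' weights and a CUDA device are
# available): the caption is prefixed with the editing system prompt, encoded with the
# second-to-last hidden layer, and the system-prompt positions are stripped again.
g2 = Gemma2Embedder(max_length=300)
batch = g2(["replace the red car with a bicycle"])
print(batch['text_embeds'].shape)  # [1, 300, hidden_size] caption embeddings, prefix removed
print(batch['text_masks'].shape)   # [1, 300]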
+ class T5TextEmbedder(nn.Module):
+     def __init__(self, device, pretrained_path="google/flan-t5-xxl", max_length=300):
+         super().__init__()
+         self.model = T5EncoderModel.from_pretrained(pretrained_path).to(device=device).to(torch.bfloat16)
+         self.tokenizer = T5Tokenizer.from_pretrained(pretrained_path)
+         self.max_length = max_length
+         self.model.eval()
+         self.model.requires_grad_(False)
+
+     @property
+     def dtype(self):
+         """Return the dtype of the model parameters."""
+         return next(self.parameters()).dtype
+
+     @property
+     def device(self):
+         """Return the device of the model parameters."""
+         return next(self.parameters()).device
+
+     def forward(self, caption):
+         max_length = self.max_length
+
+         text_inputs = self.tokenizer(
+             caption,
+             return_tensors="pt",
+             add_special_tokens=True,
+             max_length=max_length,
+             padding="max_length",
+             truncation=True,
+         )
+         text_input_ids = text_inputs.input_ids
+         attention_mask = text_inputs.attention_mask
+         text_input_ids = text_input_ids.to(self.model.device)
+         attention_mask = attention_mask.to(self.model.device)
+         outputs = self.model(text_input_ids, attention_mask=attention_mask)
+         embeddings = outputs.last_hidden_state
+         return embeddings, attention_mask.to(embeddings.dtype)
+
+
+
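# Illustrative usage sketch (assumes the 'google/flan-t5-xxl' weights fit on the target device):
# T5TextEmbedder pads/truncates to max_length and returns the encoder's last hidden state
# together with the attention mask cast to the embedding dtype.
t5 = T5TextEmbedder(device="cuda:0", max_length=300)
emb, mask = t5(["a low-poly render of a mountain village"])
print(emb.shape, mask.shape)  # [1, 300, 4096] and [1, 300] for flan-t5-xxl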
+ if __name__ == '__main__':
+
+     from datasets import load_dataset
+     dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
+     item = dataset[0:4]
+     another_item = dataset[0:4]
+     from diffusers.models.normalization import RMSNorm
+     image_encoder = CLIPImageEncoder(device="cuda:0")
+     clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
+     image_embeds = image_encoder(clip_processor(images=item['image'], return_tensors="pt").pixel_values.to("cuda:0").to(torch.bfloat16))
+     print(image_embeds.shape, ' >>>> image embeds')
+
+     #model = GemmaTextEmbedder(device="cuda:0")
+     model = LoraT5Embedder(device="cuda:0")
+     prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds, attn_mask = model(
+         [
+             """A heartwarming 3D rendered scene of
+             an elderly farmer and a tiny orange
+             kitten. The farmer, with a gentle smile,
+             walks alongside the kitten in a lush,
+             green garden filled with thriving plants,
+             showcasing a fruitful harvest. The
+             intricate details of the overalls and the
+             farmer's worn, weathered face tell a
+             story of years spent tending to the land, the farmer is wearing a blue shirt""",
+         ],
+         image=clip_processor(images=item['image'], return_tensors="pt").pixel_values.to("cuda:0").to(torch.bfloat16),
+     )
+     print(l2_loss, ' >>> l2 loss ', l1_loss, ' >>> l1 loss ', pad_loss, ' >>> pad loss ')
+     print(clip_image_embeds.shape, ' >>> clip image embeds ')
+
+     #print(gemma_dict['text_embeds'],)
+     #print(gemma_dict['image_embeds'], ' >>> image embeds')
+
+     """
+     from dataset import create_loader
+     from PIL import Image as PILImage
+     import PIL
+     import numpy as np
+     import torch.nn.functional as F
+
+     loader = create_loader('edit', batch_size=16, shuffle=False)
+     batch = next(iter(loader))
+     source = batch['source_images']
+     source_pils = [PIL.Image.fromarray(((x.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype(np.uint8)) for x in source]
+     target = batch['target_images']
+     target_pils = [PIL.Image.fromarray(((x.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype(np.uint8)) for x in target]
+     from torchvision.utils import save_image
+
+     print(batch['captions'])
+
+     images = []
+     for (x, y) in zip(batch['source_images'], batch['target_images']):
+         images.append(x)
+         images.append(y)
+     save_image((torch.stack(images) + 1) / 2, 'example_pairs.jpg', nrow=8)
+
+     gemma_dict = model(batch['captions'], source_pils, target_pils)
+     image_embeds = gemma_dict['image_embeds']
+     target_image_embeds = gemma_dict['target_image_embeds']
+
+     print("Image embeds shape:", image_embeds.shape)
+     print("Target image embeds shape:", target_image_embeds.shape)
+     from qwen import compute_and_save_similarity_grid
+     compute_and_save_similarity_grid(image_embeds, target_image_embeds, "gemma_similarity_grid.jpg")
+     """
+