linoyts HF staff commited on
Commit
1314d69
·
verified ·
1 Parent(s): c9b0866

support flux (#1)

Browse files

- support flux (38137231e4da487d1d952256e68a109b7bc006cf)
- flux slider (325d5c60fdce05cac741dfba18eb2768296cea98)
- Update app.py (058c742984f5b9ca82b83f1eca0da8157c6aa320)

Files changed (2) hide show
  1. app.py +0 -0
  2. clip_slider_pipeline.py +171 -75
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
clip_slider_pipeline.py CHANGED
@@ -4,26 +4,23 @@ import random
4
  from tqdm import tqdm
5
  from constants import SUBJECTS, MEDIUMS
6
  from PIL import Image
7
- import time
8
  class CLIPSlider:
9
  def __init__(
10
  self,
11
  sd_pipe,
12
  device: torch.device,
13
- target_word: str = "",
14
- opposite: str = "",
15
  target_word_2nd: str = "",
16
  opposite_2nd: str = "",
17
  iterations: int = 300,
18
  ):
19
 
20
  self.device = device
21
- self.pipe = sd_pipe.to(self.device, torch.float16)
22
  self.iterations = iterations
23
- if target_word != "" or opposite != "":
24
- self.avg_diff = self.find_latent_direction(target_word, opposite)
25
- else:
26
- self.avg_diff = None
27
  if target_word_2nd != "" or opposite_2nd != "":
28
  self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
29
  else:
@@ -32,21 +29,17 @@ class CLIPSlider:
32
 
33
  def find_latent_direction(self,
34
  target_word:str,
35
- opposite:str,
36
- num_iterations: int = None):
37
 
38
  # lets identify a latent direction by taking differences between opposites
39
  # target_word = "happy"
40
  # opposite = "sad"
41
 
42
- if num_iterations is not None:
43
- iterations = num_iterations
44
- else:
45
- iterations = self.iterations
46
  with torch.no_grad():
47
  positives = []
48
  negatives = []
49
- for i in tqdm(range(iterations)):
50
  medium = random.choice(MEDIUMS)
51
  subject = random.choice(SUBJECTS)
52
  pos_prompt = f"a {medium} of a {target_word} {subject}"
@@ -77,8 +70,6 @@ class CLIPSlider:
77
  only_pooler = False,
78
  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
79
  correlation_weight_factor = 1.0,
80
- avg_diff = None,
81
- avg_diff_2nd = None,
82
  **pipeline_kwargs
83
  ):
84
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -89,14 +80,14 @@ class CLIPSlider:
89
  max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
90
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
91
 
92
- if avg_diff_2nd and normalize_scales:
93
  denominator = abs(scale) + abs(scale_2nd)
94
  scale = scale / denominator
95
  scale_2nd = scale_2nd / denominator
96
  if only_pooler:
97
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
98
- if avg_diff_2nd:
99
- prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
100
  else:
101
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
102
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -108,15 +99,15 @@ class CLIPSlider:
108
 
109
  # weights = torch.sigmoid((weights-0.5)*7)
110
  prompt_embeds = prompt_embeds + (
111
- weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
112
- if avg_diff_2nd:
113
- prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
114
 
115
 
116
  torch.manual_seed(seed)
117
- image = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images[0]
118
 
119
- return image
120
 
121
  def spectrum(self,
122
  prompt="a photo of a house",
@@ -149,23 +140,19 @@ class CLIPSliderXL(CLIPSlider):
149
 
150
  def find_latent_direction(self,
151
  target_word:str,
152
- opposite:str,
153
- num_iterations: int = None):
154
 
155
  # lets identify a latent direction by taking differences between opposites
156
  # target_word = "happy"
157
  # opposite = "sad"
158
- if num_iterations is not None:
159
- iterations = num_iterations
160
- else:
161
- iterations = self.iterations
162
 
163
  with torch.no_grad():
164
  positives = []
165
  negatives = []
166
  positives2 = []
167
  negatives2 = []
168
- for i in tqdm(range(iterations)):
169
  medium = random.choice(MEDIUMS)
170
  subject = random.choice(SUBJECTS)
171
  pos_prompt = f"a {medium} of a {target_word} {subject}"
@@ -208,13 +195,11 @@ class CLIPSliderXL(CLIPSlider):
208
  only_pooler = False,
209
  normalize_scales = False,
210
  correlation_weight_factor = 1.0,
211
- avg_diff = None,
212
- avg_diff_2nd = None,
213
  **pipeline_kwargs
214
  ):
215
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
216
  # if pooler token only [-4,4] work well
217
- start_time = time.time()
218
  text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
219
  tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
220
  with torch.no_grad():
@@ -239,21 +224,20 @@ class CLIPSliderXL(CLIPSlider):
239
  toks.to(text_encoder.device),
240
  output_hidden_states=True,
241
  )
242
-
243
  # We are only ALWAYS interested in the pooled output of the final text encoder
244
- pooled_prompt_embeds = prompt_embeds[0]
245
  prompt_embeds = prompt_embeds.hidden_states[-2]
246
- print("prompt_embeds.dtype",prompt_embeds.dtype)
247
- if avg_diff_2nd and normalize_scales:
248
  denominator = abs(scale) + abs(scale_2nd)
249
  scale = scale / denominator
250
  scale_2nd = scale_2nd / denominator
251
  if only_pooler:
252
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
253
- if avg_diff_2nd:
254
- prompt_embeds[:, toks.argmax()] += avg_diff_2nd[0] * scale_2nd
255
  else:
256
-
257
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
258
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
259
 
@@ -263,58 +247,49 @@ class CLIPSliderXL(CLIPSlider):
263
  standard_weights = torch.ones_like(weights)
264
 
265
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
266
- prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
267
- if avg_diff_2nd:
268
- prompt_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
269
  else:
270
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
271
 
272
  standard_weights = torch.ones_like(weights)
273
 
274
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
275
- prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
276
- if avg_diff_2nd:
277
- prompt_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
278
 
279
  bs_embed, seq_len, _ = prompt_embeds.shape
280
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
281
  prompt_embeds_list.append(prompt_embeds)
282
 
283
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(torch.float16)
284
- pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1).to(torch.float16)
285
- end_time = time.time()
286
- print("prompt_embeds", prompt_embeds.dtype)
287
- print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
288
  torch.manual_seed(seed)
289
- start_time = time.time()
290
- image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
291
- **pipeline_kwargs).images[0]
292
- end_time = time.time()
293
- print(f"generation time - pipe: {end_time - start_time:.2f} ms")
294
 
295
- return image
296
 
297
  class CLIPSliderXL_inv(CLIPSlider):
298
 
299
  def find_latent_direction(self,
300
  target_word:str,
301
- opposite:str,
302
- num_iterations: int = None):
303
 
304
  # lets identify a latent direction by taking differences between opposites
305
  # target_word = "happy"
306
  # opposite = "sad"
307
- if num_iterations is not None:
308
- iterations = num_iterations
309
- else:
310
- iterations = self.iterations
311
 
312
  with torch.no_grad():
313
  positives = []
314
  negatives = []
315
  positives2 = []
316
  negatives2 = []
317
- for i in tqdm(range(iterations)):
318
  medium = random.choice(MEDIUMS)
319
  subject = random.choice(SUBJECTS)
320
  pos_prompt = f"a {medium} of a {target_word} {subject}"
@@ -357,18 +332,139 @@ class CLIPSliderXL_inv(CLIPSlider):
357
  only_pooler = False,
358
  normalize_scales = False,
359
  correlation_weight_factor = 1.0,
360
- avg_diff=None,
361
- avg_diff_2nd=None,
362
- init_latents=None,
363
- zs=None,
364
  **pipeline_kwargs
365
  ):
366
 
367
  with torch.no_grad():
368
  torch.manual_seed(seed)
369
- images = self.pipe(editing_prompt=prompt, init_latents=init_latents, zs=zs,
370
- avg_diff=avg_diff[0], avg_diff_2=avg_diff[1],
371
- scale=scale,
372
  **pipeline_kwargs).images
373
 
374
  return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from tqdm import tqdm
5
  from constants import SUBJECTS, MEDIUMS
6
  from PIL import Image
7
+
8
  class CLIPSlider:
9
  def __init__(
10
  self,
11
  sd_pipe,
12
  device: torch.device,
13
+ target_word: str,
14
+ opposite: str,
15
  target_word_2nd: str = "",
16
  opposite_2nd: str = "",
17
  iterations: int = 300,
18
  ):
19
 
20
  self.device = device
21
+ self.pipe = sd_pipe.to(self.device)
22
  self.iterations = iterations
23
+ self.avg_diff = self.find_latent_direction(target_word, opposite)
 
 
 
24
  if target_word_2nd != "" or opposite_2nd != "":
25
  self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
26
  else:
 
29
 
30
  def find_latent_direction(self,
31
  target_word:str,
32
+ opposite:str):
 
33
 
34
  # lets identify a latent direction by taking differences between opposites
35
  # target_word = "happy"
36
  # opposite = "sad"
37
 
38
+
 
 
 
39
  with torch.no_grad():
40
  positives = []
41
  negatives = []
42
+ for i in tqdm(range(self.iterations)):
43
  medium = random.choice(MEDIUMS)
44
  subject = random.choice(SUBJECTS)
45
  pos_prompt = f"a {medium} of a {target_word} {subject}"
 
70
  only_pooler = False,
71
  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
72
  correlation_weight_factor = 1.0,
 
 
73
  **pipeline_kwargs
74
  ):
75
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
 
80
  max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
81
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
82
 
83
+ if self.avg_diff_2nd and normalize_scales:
84
  denominator = abs(scale) + abs(scale_2nd)
85
  scale = scale / denominator
86
  scale_2nd = scale_2nd / denominator
87
  if only_pooler:
88
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
89
+ if self.avg_diff_2nd:
90
+ prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
91
  else:
92
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
93
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
99
 
100
  # weights = torch.sigmoid((weights-0.5)*7)
101
  prompt_embeds = prompt_embeds + (
102
+ weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
103
+ if self.avg_diff_2nd:
104
+ prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
105
 
106
 
107
  torch.manual_seed(seed)
108
+ images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
109
 
110
+ return images
111
 
112
  def spectrum(self,
113
  prompt="a photo of a house",
 
140
 
141
  def find_latent_direction(self,
142
  target_word:str,
143
+ opposite:str):
 
144
 
145
  # lets identify a latent direction by taking differences between opposites
146
  # target_word = "happy"
147
  # opposite = "sad"
148
+
 
 
 
149
 
150
  with torch.no_grad():
151
  positives = []
152
  negatives = []
153
  positives2 = []
154
  negatives2 = []
155
+ for i in tqdm(range(self.iterations)):
156
  medium = random.choice(MEDIUMS)
157
  subject = random.choice(SUBJECTS)
158
  pos_prompt = f"a {medium} of a {target_word} {subject}"
 
195
  only_pooler = False,
196
  normalize_scales = False,
197
  correlation_weight_factor = 1.0,
 
 
198
  **pipeline_kwargs
199
  ):
200
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
201
  # if pooler token only [-4,4] work well
202
+
203
  text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
204
  tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
205
  with torch.no_grad():
 
224
  toks.to(text_encoder.device),
225
  output_hidden_states=True,
226
  )
227
+
228
  # We are only ALWAYS interested in the pooled output of the final text encoder
229
+ pooled_prompt_embeds = prompt_embeds[0]
230
  prompt_embeds = prompt_embeds.hidden_states[-2]
231
+
232
+ if self.avg_diff_2nd and normalize_scales:
233
  denominator = abs(scale) + abs(scale_2nd)
234
  scale = scale / denominator
235
  scale_2nd = scale_2nd / denominator
236
  if only_pooler:
237
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
238
+ if self.avg_diff_2nd:
239
+ prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[0] * scale_2nd
240
  else:
 
241
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
242
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
243
 
 
247
  standard_weights = torch.ones_like(weights)
248
 
249
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
250
+ prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
251
+ if self.avg_diff_2nd:
252
+ prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
253
  else:
254
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
255
 
256
  standard_weights = torch.ones_like(weights)
257
 
258
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
259
+ prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
260
+ if self.avg_diff_2nd:
261
+ prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
262
 
263
  bs_embed, seq_len, _ = prompt_embeds.shape
264
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
265
  prompt_embeds_list.append(prompt_embeds)
266
 
267
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
268
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
269
+
 
 
270
  torch.manual_seed(seed)
271
+ images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
272
+ **pipeline_kwargs).images
 
 
 
273
 
274
+ return images
275
 
276
  class CLIPSliderXL_inv(CLIPSlider):
277
 
278
  def find_latent_direction(self,
279
  target_word:str,
280
+ opposite:str):
 
281
 
282
  # lets identify a latent direction by taking differences between opposites
283
  # target_word = "happy"
284
  # opposite = "sad"
285
+
 
 
 
286
 
287
  with torch.no_grad():
288
  positives = []
289
  negatives = []
290
  positives2 = []
291
  negatives2 = []
292
+ for i in tqdm(range(self.iterations)):
293
  medium = random.choice(MEDIUMS)
294
  subject = random.choice(SUBJECTS)
295
  pos_prompt = f"a {medium} of a {target_word} {subject}"
 
332
  only_pooler = False,
333
  normalize_scales = False,
334
  correlation_weight_factor = 1.0,
 
 
 
 
335
  **pipeline_kwargs
336
  ):
337
 
338
  with torch.no_grad():
339
  torch.manual_seed(seed)
340
+ images = self.pipe(editing_prompt=prompt,
341
+ avg_diff=self.avg_diff, avg_diff_2nd=self.avg_diff_2nd,
342
+ scale=scale, scale_2nd=scale_2nd,
343
  **pipeline_kwargs).images
344
 
345
  return images
346
+
347
+
348
+ class T5SliderFlux(CLIPSlider):
349
+
350
+ def find_latent_direction(self,
351
+ target_word:str,
352
+ opposite:str):
353
+
354
+ # lets identify a latent direction by taking differences between opposites
355
+ # target_word = "happy"
356
+ # opposite = "sad"
357
+
358
+
359
+ with torch.no_grad():
360
+ positives = []
361
+ negatives = []
362
+ for i in tqdm(range(self.iterations)):
363
+ medium = random.choice(MEDIUMS)
364
+ subject = random.choice(SUBJECTS)
365
+ pos_prompt = f"a {medium} of a {target_word} {subject}"
366
+ neg_prompt = f"a {medium} of a {opposite} {subject}"
367
+
368
+ pos_toks = self.pipe.tokenizer_2(pos_prompt,
369
+ return_tensors="pt",
370
+ padding="max_length",
371
+ truncation=True,
372
+ return_length=False,
373
+ return_overflowing_tokens=False,
374
+ max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
375
+ neg_toks = self.pipe.tokenizer_2(neg_prompt,
376
+ return_tensors="pt",
377
+ padding="max_length",
378
+ truncation=True,
379
+ return_length=False,
380
+ return_overflowing_tokens=False,
381
+ max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
382
+ pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
383
+ neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
384
+ positives.append(pos)
385
+ negatives.append(neg)
386
+
387
+ positives = torch.cat(positives, dim=0)
388
+ negatives = torch.cat(negatives, dim=0)
389
+ diffs = positives - negatives
390
+ avg_diff = diffs.mean(0, keepdim=True)
391
+
392
+ return avg_diff
393
+
394
+ def generate(self,
395
+ prompt = "a photo of a house",
396
+ scale = 2,
397
+ scale_2nd = 2,
398
+ seed = 15,
399
+ only_pooler = False,
400
+ normalize_scales = False,
401
+ correlation_weight_factor = 1.0,
402
+ **pipeline_kwargs
403
+ ):
404
+ # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
405
+ # if pooler token only [-4,4] work well
406
+
407
+ with torch.no_grad():
408
+ text_inputs = self.pipe.tokenizer(
409
+ prompt,
410
+ padding="max_length",
411
+ max_length=77,
412
+ truncation=True,
413
+ return_overflowing_tokens=False,
414
+ return_length=False,
415
+ return_tensors="pt",
416
+ )
417
+
418
+ text_input_ids = text_inputs.input_ids
419
+ prompt_embeds = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
420
+
421
+ # Use pooled output of CLIPTextModel
422
+ prompt_embeds = prompt_embeds.pooler_output
423
+ pooled_prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder.dtype, device=self.device)
424
+
425
+ # Use pooled output of CLIPTextModel
426
+
427
+ text_inputs = self.pipe.tokenizer_2(
428
+ prompt,
429
+ padding="max_length",
430
+ max_length=512,
431
+ truncation=True,
432
+ return_length=False,
433
+ return_overflowing_tokens=False,
434
+ return_tensors="pt",
435
+ )
436
+ toks = text_inputs.input_ids
437
+ prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
438
+ dtype = self.pipe.text_encoder_2.dtype
439
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
440
+ print("1", prompt_embeds.shape)
441
+ if self.avg_diff_2nd and normalize_scales:
442
+ denominator = abs(scale) + abs(scale_2nd)
443
+ scale = scale / denominator
444
+ scale_2nd = scale_2nd / denominator
445
+ if only_pooler:
446
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
447
+ if self.avg_diff_2nd:
448
+ prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
449
+ else:
450
+ normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
451
+ sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
452
+
453
+ weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, prompt_embeds.shape[2])
454
+ print("weights", weights.shape)
455
+
456
+ standard_weights = torch.ones_like(weights)
457
+
458
+ weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
459
+ prompt_embeds = prompt_embeds + (
460
+ weights * self.avg_diff * scale)
461
+ print("2", prompt_embeds.shape)
462
+ if self.avg_diff_2nd:
463
+ prompt_embeds += (
464
+ weights * self.avg_diff_2nd * scale_2nd)
465
+
466
+ torch.manual_seed(seed)
467
+ images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
468
+ **pipeline_kwargs).images
469
+
470
+ return images