linoyts HF staff commited on
Commit
6ef419c
Β·
verified Β·
1 Parent(s): b1c5569

take avg_diff out of attributes and save to cpu (#1)

Browse files

- take avg_diff out of attributes and save to cpu (a64b549e3ec96a9cf832c237af6bd7dd3baae807)
- Update clip_slider_pipeline.py (224212aeaf4f1ed103d432f5d08e8eee69601810)

Files changed (2) hide show
  1. app.py +9 -5
  2. clip_slider_pipeline.py +22 -165
app.py CHANGED
@@ -17,14 +17,14 @@ def generate(slider_x, slider_y, prompt,
17
 
18
  # check if avg diff for directions need to be re-calculated
19
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
20
- clip_slider.avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1])
21
  x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
22
 
23
  if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
24
- clip_slider.avg_diff_2nd = clip_slider.find_latent_direction(slider_y[0], slider_y[1])
25
  y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
26
 
27
- image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=8)
28
  comma_concepts_x = ', '.join(slider_x)
29
  comma_concepts_y = ', '.join(slider_y)
30
 
@@ -36,11 +36,15 @@ def generate(slider_x, slider_y, prompt,
36
  return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, image
37
 
38
  def update_x(x,y,prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
39
- image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
 
 
40
  return image
41
 
42
  def update_y(x,y,prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
43
- image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
 
 
44
  return image
45
 
46
  css = '''
 
17
 
18
  # check if avg diff for directions need to be re-calculated
19
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
20
+ avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1])
21
  x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
22
 
23
  if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
24
+ avg_diff_2nd = clip_slider.find_latent_direction(slider_y[0], slider_y[1])
25
  y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
26
 
27
+ image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=8, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
28
  comma_concepts_x = ', '.join(slider_x)
29
  comma_concepts_y = ', '.join(slider_y)
30
 
 
36
  return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, image
37
 
38
  def update_x(x,y,prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
39
+ avg_diff = [avg_diff_x_1.cuda(), avg_diff_x_2.cuda()]
40
+ avg_diff_2nd = [avg_diff_y_1.cuda(), avg_diff_y_2.cuda()]
41
+ image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
42
  return image
43
 
44
  def update_y(x,y,prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
45
+ avg_diff = [avg_diff_x_1.cuda(), avg_diff_x_2.cuda()]
46
+ avg_diff_2nd = [avg_diff_y_1.cuda(), avg_diff_y_2.cuda()]
47
+ image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
48
  return image
49
 
50
  css = '''
clip_slider_pipeline.py CHANGED
@@ -73,6 +73,8 @@ class CLIPSlider:
73
  only_pooler = False,
74
  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
75
  correlation_weight_factor = 1.0,
 
 
76
  **pipeline_kwargs
77
  ):
78
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -83,14 +85,14 @@ class CLIPSlider:
83
  max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
84
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
85
 
86
- if self.avg_diff_2nd and normalize_scales:
87
  denominator = abs(scale) + abs(scale_2nd)
88
  scale = scale / denominator
89
  scale_2nd = scale_2nd / denominator
90
  if only_pooler:
91
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
92
- if self.avg_diff_2nd:
93
- prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
94
  else:
95
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
96
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -102,9 +104,9 @@ class CLIPSlider:
102
 
103
  # weights = torch.sigmoid((weights-0.5)*7)
104
  prompt_embeds = prompt_embeds + (
105
- weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
106
- if self.avg_diff_2nd:
107
- prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
108
 
109
 
110
  torch.manual_seed(seed)
@@ -198,6 +200,8 @@ class CLIPSliderXL(CLIPSlider):
198
  only_pooler = False,
199
  normalize_scales = False,
200
  correlation_weight_factor = 1.0,
 
 
201
  **pipeline_kwargs
202
  ):
203
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -232,16 +236,16 @@ class CLIPSliderXL(CLIPSlider):
232
  pooled_prompt_embeds = prompt_embeds[0]
233
  prompt_embeds = prompt_embeds.hidden_states[-2]
234
 
235
- if self.avg_diff_2nd and normalize_scales:
236
  denominator = abs(scale) + abs(scale_2nd)
237
  scale = scale / denominator
238
  scale_2nd = scale_2nd / denominator
239
  if only_pooler:
240
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
241
- if self.avg_diff_2nd:
242
- prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[0] * scale_2nd
243
  else:
244
- print(self.avg_diff)
245
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
246
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
247
 
@@ -251,18 +255,18 @@ class CLIPSliderXL(CLIPSlider):
251
  standard_weights = torch.ones_like(weights)
252
 
253
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
254
- prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
255
- if self.avg_diff_2nd:
256
- prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
257
  else:
258
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
259
 
260
  standard_weights = torch.ones_like(weights)
261
 
262
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
263
- prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
264
- if self.avg_diff_2nd:
265
- prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
266
 
267
  bs_embed, seq_len, _ = prompt_embeds.shape
268
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
@@ -276,150 +280,3 @@ class CLIPSliderXL(CLIPSlider):
276
  **pipeline_kwargs).images[0]
277
 
278
  return image
279
-
280
-
281
- class CLIPSlider3(CLIPSlider):
282
- def find_latent_direction(self,
283
- target_word:str,
284
- opposite:str):
285
-
286
- # lets identify a latent direction by taking differences between opposites
287
- # target_word = "happy"
288
- # opposite = "sad"
289
-
290
-
291
- with torch.no_grad():
292
- positives = []
293
- negatives = []
294
- positives2 = []
295
- negatives2 = []
296
- for i in tqdm(range(self.iterations)):
297
- medium = random.choice(MEDIUMS)
298
- subject = random.choice(SUBJECTS)
299
- pos_prompt = f"a {medium} of a {target_word} {subject}"
300
- neg_prompt = f"a {medium} of a {opposite} {subject}"
301
-
302
- pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
303
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
304
- neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
305
- max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
306
- pos = self.pipe.text_encoder(pos_toks).text_embeds
307
- neg = self.pipe.text_encoder(neg_toks).text_embeds
308
- positives.append(pos)
309
- negatives.append(neg)
310
-
311
- pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
312
- max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
313
- neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
314
- max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
315
- pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
316
- neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
317
- positives2.append(pos2)
318
- negatives2.append(neg2)
319
-
320
- positives = torch.cat(positives, dim=0)
321
- negatives = torch.cat(negatives, dim=0)
322
- diffs = positives - negatives
323
- avg_diff = diffs.mean(0, keepdim=True)
324
-
325
- positives2 = torch.cat(positives2, dim=0)
326
- negatives2 = torch.cat(negatives2, dim=0)
327
- diffs2 = positives2 - negatives2
328
- avg_diff2 = diffs2.mean(0, keepdim=True)
329
- return (avg_diff, avg_diff2)
330
-
331
- def generate(self,
332
- prompt = "a photo of a house",
333
- scale = 2,
334
- seed = 15,
335
- only_pooler = False,
336
- correlation_weight_factor = 1.0,
337
- ** pipeline_kwargs
338
- ):
339
- # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
340
- # if pooler token only [-4,4] work well
341
- clip_text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
342
- clip_tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
343
- with torch.no_grad():
344
- # toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.cuda()
345
- # prompt_embeds = pipe.text_encoder(toks).last_hidden_state
346
-
347
- clip_prompt_embeds_list = []
348
- clip_pooled_prompt_embeds_list = []
349
- for i, text_encoder in enumerate(clip_text_encoders):
350
-
351
- if i < 2:
352
- tokenizer = clip_tokenizers[i]
353
- text_inputs = tokenizer(
354
- prompt,
355
- padding="max_length",
356
- max_length=tokenizer.model_max_length,
357
- truncation=True,
358
- return_tensors="pt",
359
- )
360
- toks = text_inputs.input_ids
361
-
362
- prompt_embeds = text_encoder(
363
- toks.to(text_encoder.device),
364
- output_hidden_states=True,
365
- )
366
-
367
- # We are only ALWAYS interested in the pooled output of the final text encoder
368
- pooled_prompt_embeds = prompt_embeds[0]
369
- pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
370
- clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
371
- prompt_embeds = prompt_embeds.hidden_states[-2]
372
- else:
373
- text_inputs = self.pipe.tokenizer_3(
374
- prompt,
375
- padding="max_length",
376
- max_length=self.tokenizer_max_length,
377
- truncation=True,
378
- add_special_tokens=True,
379
- return_tensors="pt",
380
- )
381
- toks = text_inputs.input_ids
382
- prompt_embeds = self.pipe.text_encoder_3(toks.to(self.device))[0]
383
- t5_prompt_embed_shape = prompt_embeds.shape[-1]
384
-
385
- if only_pooler:
386
- prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
387
- else:
388
- normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
389
- sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
390
- if i == 0:
391
- weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
392
-
393
- standard_weights = torch.ones_like(weights)
394
-
395
- weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
396
- prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
397
- else:
398
- weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
399
-
400
- standard_weights = torch.ones_like(weights)
401
-
402
- weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
403
- prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
404
-
405
- bs_embed, seq_len, _ = prompt_embeds.shape
406
- prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
407
- if i < 2:
408
- clip_prompt_embeds_list.append(prompt_embeds)
409
-
410
- clip_prompt_embeds = torch.concat(clip_prompt_embeds_list, dim=-1)
411
- clip_pooled_prompt_embeds = torch.concat(clip_pooled_prompt_embeds_list, dim=-1)
412
-
413
- clip_prompt_embeds = torch.nn.functional.pad(
414
- clip_prompt_embeds, (0, t5_prompt_embed_shape - clip_prompt_embeds.shape[-1])
415
- )
416
-
417
- prompt_embeds = torch.cat([clip_prompt_embeds, prompt_embeds], dim=-2)
418
-
419
-
420
-
421
- torch.manual_seed(seed)
422
- image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=clip_pooled_prompt_embeds,
423
- **pipeline_kwargs).images[0]
424
-
425
- return image
 
73
  only_pooler = False,
74
  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
75
  correlation_weight_factor = 1.0,
76
+ avg_diff = None,
77
+ avg_diff_2nd = None,
78
  **pipeline_kwargs
79
  ):
80
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
 
85
  max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
86
  prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
87
 
88
+ if avg_diff_2nd and normalize_scales:
89
  denominator = abs(scale) + abs(scale_2nd)
90
  scale = scale / denominator
91
  scale_2nd = scale_2nd / denominator
92
  if only_pooler:
93
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
94
+ if avg_diff_2nd:
95
+ prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
96
  else:
97
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
98
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
104
 
105
  # weights = torch.sigmoid((weights-0.5)*7)
106
  prompt_embeds = prompt_embeds + (
107
+ weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
108
+ if avg_diff_2nd:
109
+ prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
110
 
111
 
112
  torch.manual_seed(seed)
 
200
  only_pooler = False,
201
  normalize_scales = False,
202
  correlation_weight_factor = 1.0,
203
+ avg_diff = None,
204
+ avg_diff_2nd = None,
205
  **pipeline_kwargs
206
  ):
207
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
 
236
  pooled_prompt_embeds = prompt_embeds[0]
237
  prompt_embeds = prompt_embeds.hidden_states[-2]
238
 
239
+ if avg_diff_2nd and normalize_scales:
240
  denominator = abs(scale) + abs(scale_2nd)
241
  scale = scale / denominator
242
  scale_2nd = scale_2nd / denominator
243
  if only_pooler:
244
+ prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
245
+ if avg_diff_2nd:
246
+ prompt_embeds[:, toks.argmax()] += avg_diff_2nd[0] * scale_2nd
247
  else:
248
+ print(avg_diff)
249
  normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
250
  sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
251
 
 
255
  standard_weights = torch.ones_like(weights)
256
 
257
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
258
+ prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
259
+ if avg_diff_2nd:
260
+ prompt_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
261
  else:
262
  weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
263
 
264
  standard_weights = torch.ones_like(weights)
265
 
266
  weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
267
+ prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
268
+ if avg_diff_2nd:
269
+ prompt_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
270
 
271
  bs_embed, seq_len, _ = prompt_embeds.shape
272
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
 
280
  **pipeline_kwargs).images[0]
281
 
282
  return image