Spaces:
Running
on
Zero
Running
on
Zero
Update ledits/pipeline_leditspp_stable_diffusion_xl.py
Browse files
ledits/pipeline_leditspp_stable_diffusion_xl.py
CHANGED
@@ -415,10 +415,11 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
415 |
editing_prompt: Optional[str] = None,
|
416 |
editing_prompt_embeds: Optional[torch.Tensor] = None,
|
417 |
editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
418 |
-
avg_diff
|
419 |
-
|
420 |
-
correlation_weight_factor
|
421 |
scale=2,
|
|
|
422 |
) -> object:
|
423 |
r"""
|
424 |
Encodes the prompt into text encoder hidden states.
|
@@ -538,9 +539,8 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
538 |
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
539 |
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
540 |
|
541 |
-
if avg_diff is not None
|
542 |
-
#scale=3
|
543 |
-
print("SHALOM neg")
|
544 |
normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
|
545 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
546 |
if j == 0:
|
@@ -549,15 +549,26 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
549 |
standard_weights = torch.ones_like(weights)
|
550 |
|
551 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
552 |
-
edit_concepts_embeds = negative_prompt_embeds + (
|
|
|
|
|
|
|
|
|
|
|
|
|
553 |
else:
|
554 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
555 |
|
556 |
standard_weights = torch.ones_like(weights)
|
557 |
|
558 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
559 |
-
edit_concepts_embeds = negative_prompt_embeds + (
|
|
|
560 |
|
|
|
|
|
|
|
|
|
561 |
|
562 |
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
563 |
j+=1
|
@@ -878,10 +889,12 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
878 |
clip_skip: Optional[int] = None,
|
879 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
880 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
881 |
-
avg_diff
|
882 |
-
|
883 |
-
correlation_weight_factor
|
884 |
scale=2,
|
|
|
|
|
885 |
init_latents: [torch.Tensor] = None,
|
886 |
zs: [torch.Tensor] = None,
|
887 |
**kwargs,
|
@@ -1088,9 +1101,10 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
1088 |
editing_prompt_embeds=editing_prompt_embeddings,
|
1089 |
editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
|
1090 |
avg_diff = avg_diff,
|
1091 |
-
|
1092 |
correlation_weight_factor = correlation_weight_factor,
|
1093 |
scale=scale,
|
|
|
1094 |
)
|
1095 |
|
1096 |
# 4. Prepare timesteps
|
|
|
415 |
editing_prompt: Optional[str] = None,
|
416 |
editing_prompt_embeds: Optional[torch.Tensor] = None,
|
417 |
editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
418 |
+
avg_diff=None, # [0] -> text encoder 1,[1] ->text encoder 2
|
419 |
+
avg_diff_2nd=None, # text encoder 1,2
|
420 |
+
correlation_weight_factor=0.7,
|
421 |
scale=2,
|
422 |
+
scale_2nd=2,
|
423 |
) -> object:
|
424 |
r"""
|
425 |
Encodes the prompt into text encoder hidden states.
|
|
|
539 |
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
540 |
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
541 |
|
542 |
+
if avg_diff is not None:
|
543 |
+
# scale=3
|
|
|
544 |
normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
|
545 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
546 |
if j == 0:
|
|
|
549 |
standard_weights = torch.ones_like(weights)
|
550 |
|
551 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
552 |
+
edit_concepts_embeds = negative_prompt_embeds + (
|
553 |
+
weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
554 |
+
|
555 |
+
if avg_diff_2nd is not None:
|
556 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
|
557 |
+
self.pipe.tokenizer.model_max_length,
|
558 |
+
1) * scale_2nd)
|
559 |
else:
|
560 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
561 |
|
562 |
standard_weights = torch.ones_like(weights)
|
563 |
|
564 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
565 |
+
edit_concepts_embeds = negative_prompt_embeds + (
|
566 |
+
weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
567 |
|
568 |
+
if avg_diff_2nd is not None:
|
569 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
|
570 |
+
self.pipe.tokenizer_2.model_max_length,
|
571 |
+
1) * scale_2nd)
|
572 |
|
573 |
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
574 |
j+=1
|
|
|
889 |
clip_skip: Optional[int] = None,
|
890 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
891 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
892 |
+
avg_diff=None, # [0] -> text encoder 1,[1] ->text encoder 2
|
893 |
+
avg_diff_2nd=None, # text encoder 1,2
|
894 |
+
correlation_weight_factor=0.7,
|
895 |
scale=2,
|
896 |
+
scale_2nd=2,
|
897 |
+
correlation_weight_factor = 0.7,
|
898 |
init_latents: [torch.Tensor] = None,
|
899 |
zs: [torch.Tensor] = None,
|
900 |
**kwargs,
|
|
|
1101 |
editing_prompt_embeds=editing_prompt_embeddings,
|
1102 |
editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
|
1103 |
avg_diff = avg_diff,
|
1104 |
+
avg_diff_2nd = avg_diff_2nd,
|
1105 |
correlation_weight_factor = correlation_weight_factor,
|
1106 |
scale=scale,
|
1107 |
+
scale_2nd=scale_2nd
|
1108 |
)
|
1109 |
|
1110 |
# 4. Prepare timesteps
|