Linoy Tsaban commited on
Commit
98c6b44
1 Parent(s): 001613c

Update app.py

Browse files

add option for multiple warm-up steps

Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -8,6 +8,7 @@ from utils import *
8
  from inversion_utils import *
9
  from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
10
  from torch import autocast, inference_mode
 
11
 
12
  def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
13
 
@@ -74,7 +75,7 @@ def edit(input_image,
74
  tar_cfg_scale=15,
75
  edit_concept="",
76
  sega_edit_guidance=0,
77
- # warm_up=1,
78
  # neg_guidance=False,
79
  left = 0,
80
  right = 0,
@@ -98,8 +99,11 @@ def edit(input_image,
98
 
99
  if not edit_concept or not sega_edit_guidance:
100
  return pure_ddpm_out, pure_ddpm_out
 
101
  # SEGA
 
102
  edit_concepts = edit_concept.split(",")
 
103
  neg_guidance =[]
104
  for edit_concept in edit_concepts:
105
  if edit_concept.startswith("-"):
@@ -107,15 +111,25 @@ def edit(input_image,
107
  else:
108
  neg_guidance.append(False)
109
  edit_concepts = [concept.strip("+|-") for concept in edit_concepts]
110
-
111
- default_warm_up = [1]*len(edit_concepts)
112
-
 
 
 
 
 
 
 
 
 
 
113
  editing_args = dict(
114
  editing_prompt = edit_concepts,
115
  reverse_editing_direction = neg_guidance,
116
- edit_warmup_steps=default_warm_up,
117
- edit_guidance_scale=[sega_edit_guidance],
118
- edit_threshold=[.93],
119
  edit_momentum_scale=0.5,
120
  edit_mom_beta=0.6
121
  )
 
8
  from inversion_utils import *
9
  from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
10
  from torch import autocast, inference_mode
11
+ import re
12
 
13
  def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
14
 
 
75
  tar_cfg_scale=15,
76
  edit_concept="",
77
  sega_edit_guidance=0,
78
+ warm_up=None,
79
  # neg_guidance=False,
80
  left = 0,
81
  right = 0,
 
99
 
100
  if not edit_concept or not sega_edit_guidance:
101
  return pure_ddpm_out, pure_ddpm_out
102
+
103
  # SEGA
104
+ # parse concepts and neg guidance
105
  edit_concepts = edit_concept.split(",")
106
+ num_concepts = len(edit_concepts)
107
  neg_guidance =[]
108
  for edit_concept in edit_concepts:
109
  if edit_concept.startswith("-"):
 
111
  else:
112
  neg_guidance.append(False)
113
  edit_concepts = [concept.strip("+|-") for concept in edit_concepts]
114
+
115
+ # parse warm-up steps
116
+ default_warm_up_steps = [1]*num_concepts
117
+ if warm_up:
118
+ digit_pattern = re.compile(r"^\d+$")
119
+ warm_up_steps_str = warm_up.split(",")
120
+ for i,num_steps in enumerate(warm_up_steps[:num_concepts]):
121
+ if not digit_pattern.match(num_steps):
122
+ raise gr.Error("Invalid value for warm-up steps, using 1 instead")
123
+ else:
124
+ default_warm_up_steps[i] = int(num_steps)
125
+
126
+
127
  editing_args = dict(
128
  editing_prompt = edit_concepts,
129
  reverse_editing_direction = neg_guidance,
130
+ edit_warmup_steps=default_warm_up_steps,
131
+ edit_guidance_scale=[sega_edit_guidance]*num_concepts,
132
+ edit_threshold=[.93]*num_concepts,
133
  edit_momentum_scale=0.5,
134
  edit_mom_beta=0.6
135
  )