multimodalart HF staff commited on
Commit
968ec9f
1 Parent(s): 1a2ccb3
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -106,7 +106,7 @@ def load_and_invert(
106
  progress=gr.Progress(track_tqdm=True),
107
  ):
108
  # x0 = load_512(input_image, device=device).to(torch.float16)
109
- do_inversion = True
110
  if do_inversion or randomize_seed:
111
  seed = randomize_seed_fn(seed, randomize_seed)
112
  seed_everything(seed)
@@ -121,8 +121,8 @@ def load_and_invert(
121
  )
122
  wts = wts_tensor
123
  zs = zs_tensor
124
- #do_inversion = False
125
-
126
  return wts, zs, do_inversion, gr.update(visible=False)
127
 
128
  ## SEGA ##
@@ -159,8 +159,7 @@ def edit(input_image,
159
  elif(mask_type=="Intersect Mask"):
160
  use_cross_attn_mask = False
161
  use_intersect_mask = True
162
-
163
- do_inversion = True
164
  if randomize_seed:
165
  seed = randomize_seed_fn(seed, randomize_seed)
166
  seed_everything(seed)
@@ -176,7 +175,7 @@ def edit(input_image,
176
  )
177
  wts = wts_tensor
178
  zs = zs_tensor
179
- #do_inversion = False
180
 
181
  if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run pure sega
182
  tar_prompt = ""
 
106
  progress=gr.Progress(track_tqdm=True),
107
  ):
108
  # x0 = load_512(input_image, device=device).to(torch.float16)
109
+
110
  if do_inversion or randomize_seed:
111
  seed = randomize_seed_fn(seed, randomize_seed)
112
  seed_everything(seed)
 
121
  )
122
  wts = wts_tensor
123
  zs = zs_tensor
124
+ do_inversion = False
125
+
126
  return wts, zs, do_inversion, gr.update(visible=False)
127
 
128
  ## SEGA ##
 
159
  elif(mask_type=="Intersect Mask"):
160
  use_cross_attn_mask = False
161
  use_intersect_mask = True
162
+
 
163
  if randomize_seed:
164
  seed = randomize_seed_fn(seed, randomize_seed)
165
  seed_everything(seed)
 
175
  )
176
  wts = wts_tensor
177
  zs = zs_tensor
178
+ do_inversion = False
179
 
180
  if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run pure sega
181
  tar_prompt = ""