Linoy Tsaban commited on
Commit
7cbd357
1 Parent(s): 26f7dab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -110,7 +110,6 @@ def get_example():
110
  ]]
111
  return case
112
 
113
- inversion_map = dict()
114
 
115
  def invert_and_reconstruct(
116
  input_image,
@@ -142,14 +141,12 @@ def invert_and_reconstruct(
142
  #pure DDPM output
143
  pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
144
  cfg_scale_tar=tar_cfg_scale, skip=skip)
145
- inversion_map['latnets'] = latnets
146
- inversion_map['zs'] = zs
147
- inversion_map['wts'] = wts
148
 
149
  return pure_ddpm_out
150
 
151
- def reset():
152
- inversion_map.clear()
153
 
154
  def edit(input_image,
155
  src_prompt ="",
@@ -165,10 +162,18 @@ def edit(input_image,
165
  seed =0,
166
  ):
167
  torch.manual_seed(seed)
168
- if not bool(inversion_map):
169
- raise gr.Error("Must invert before editing")
170
- latnets, zs, wts = inversion_map['latnets'],inversion_map['zs'],inversion_map['wts']
171
-
 
 
 
 
 
 
 
 
172
  # SEGA
173
  # parse concepts and neg guidance
174
  edit_concepts = edit_concept.split(",")
 
110
  ]]
111
  return case
112
 
 
113
 
114
  def invert_and_reconstruct(
115
  input_image,
 
141
  #pure DDPM output
142
  pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
143
  cfg_scale_tar=tar_cfg_scale, skip=skip)
144
+ # inversion_map['latnets'] = latnets
145
+ # inversion_map['zs'] = zs
146
+ # inversion_map['wts'] = wts
147
 
148
  return pure_ddpm_out
149
 
 
 
150
 
151
  def edit(input_image,
152
  src_prompt ="",
 
162
  seed =0,
163
  ):
164
  torch.manual_seed(seed)
165
+ # if not bool(inversion_map):
166
+ # raise gr.Error("Must invert before editing")
167
+ # latnets, zs, wts = inversion_map['latnets'],inversion_map['zs'],inversion_map['wts']
168
+
169
+ x0 = load_512(input_image, left,right, top, bottom, device)
170
+
171
+ # invert
172
+ # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
173
+ wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps)
174
+
175
+ latnets = wts[skip].expand(1, -1, -1, -1)
176
+
177
  # SEGA
178
  # parse concepts and neg guidance
179
  edit_concepts = edit_concept.split(",")