Linoy Tsaban commited on
Commit
277aca5
1 Parent(s): 4eb55a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -15
app.py CHANGED
@@ -50,25 +50,22 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
51
  sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
52
  sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
 
53
 
54
-
55
- def edit(input_image,
56
  src_prompt,
57
  tar_prompt,
58
  steps,
59
- src_cfg_scale,
60
  skip,
61
- tar_cfg_scale,
62
- edit_concept,
63
- sega_edit_guidance,
64
- warm_up,
65
- neg_guidance):
66
  offsets=(0,0,0,0)
67
  x0 = load_512(input_image, *offsets, device)
68
 
69
 
70
  # invert
71
- wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
 
72
  latnets = wts[skip].expand(1, -1, -1, -1)
73
 
74
  eta = 1
@@ -76,7 +73,15 @@ def edit(input_image,
76
  pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
77
  cfg_scale_tar=tar_cfg_scale, skip=skip,
78
  eta = eta)
79
-
 
 
 
 
 
 
 
 
80
  editing_args = dict(
81
  editing_prompt = [edit_concept],
82
  reverse_editing_direction = [neg_guidance],
@@ -90,7 +95,48 @@ def edit(input_image,
90
  num_images_per_prompt=1,
91
  num_inference_steps=steps,
92
  use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
93
- return pure_ddpm_out,sega_out.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
 
96
  ####################################
@@ -132,7 +178,7 @@ with gr.Blocks() as demo:
132
  with gr.Row():
133
  #inversion
134
  steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
135
- src_cfg_scale = gr.Number(value=3.5, label=f"Source CFG", interactive=True)
136
  # reconstruction
137
  skip = gr.Number(value=36, precision=0, label="Skip", interactive=True)
138
  tar_cfg_scale = gr.Number(value=15, label=f"Reconstruction CFG", interactive=True)
@@ -146,20 +192,28 @@ with gr.Blocks() as demo:
146
  # gr.Markdown(help_text)
147
 
148
  generate_button.click(
149
- fn=edit,
150
  inputs=[input_image,
151
  src_prompt,
152
  tar_prompt,
153
  steps,
154
  src_cfg_scale,
155
  skip,
156
- tar_cfg_scale,
 
 
 
 
 
 
 
 
157
  edit_concept,
158
  sega_edit_guidance,
159
  warm_up,
160
  neg_guidance
161
  ],
162
- outputs=[ddpm_edited_image, sega_edited_image],
163
  )
164
 
165
 
 
50
  sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
51
  sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
52
  sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
53
+ latents, wts, zs = None, None, None
54
 
55
+ def invert_and_reconstruct(input_image,
 
56
  src_prompt,
57
  tar_prompt,
58
  steps,
59
+ # src_cfg_scale,
60
  skip,
61
+ tar_cfg_scale):
 
 
 
 
62
  offsets=(0,0,0,0)
63
  x0 = load_512(input_image, *offsets, device)
64
 
65
 
66
  # invert
67
+ # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
68
+ wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps)
69
  latnets = wts[skip].expand(1, -1, -1, -1)
70
 
71
  eta = 1
 
73
  pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
74
  cfg_scale_tar=tar_cfg_scale, skip=skip,
75
  eta = eta)
76
+ return pure_ddpm_out
77
+
78
+ def edit( tar_prompt,
79
+ steps,
80
+ edit_concept,
81
+ sega_edit_guidance,
82
+ warm_up,
83
+ neg_guidance):
84
+
85
  editing_args = dict(
86
  editing_prompt = [edit_concept],
87
  reverse_editing_direction = [neg_guidance],
 
95
  num_images_per_prompt=1,
96
  num_inference_steps=steps,
97
  use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
98
+ return sega_out.images[0]
99
+
100
+ # def edit(input_image,
101
+ # src_prompt,
102
+ # tar_prompt,
103
+ # steps,
104
+ # # src_cfg_scale,
105
+ # skip,
106
+ # tar_cfg_scale,
107
+ # edit_concept,
108
+ # sega_edit_guidance,
109
+ # warm_up,
110
+ # neg_guidance):
111
+ # offsets=(0,0,0,0)
112
+ # x0 = load_512(input_image, *offsets, device)
113
+
114
+
115
+ # # invert
116
+ # # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
117
+ # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps)
118
+ # latnets = wts[skip].expand(1, -1, -1, -1)
119
+
120
+ # eta = 1
121
+ # #pure DDPM output
122
+ # pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
123
+ # cfg_scale_tar=tar_cfg_scale, skip=skip,
124
+ # eta = eta)
125
+
126
+ # editing_args = dict(
127
+ # editing_prompt = [edit_concept],
128
+ # reverse_editing_direction = [neg_guidance],
129
+ # edit_warmup_steps=[warm_up],
130
+ # edit_guidance_scale=[sega_edit_guidance],
131
+ # edit_threshold=[.93],
132
+ # edit_momentum_scale=0.5,
133
+ # edit_mom_beta=0.6
134
+ # )
135
+ # sega_out = sem_pipe(prompt=tar_prompt,eta=eta, latents=latnets,
136
+ # num_images_per_prompt=1,
137
+ # num_inference_steps=steps,
138
+ # use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
139
+ # return pure_ddpm_out,sega_out.images[0]
140
 
141
 
142
  ####################################
 
178
  with gr.Row():
179
  #inversion
180
  steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
181
+ # src_cfg_scale = gr.Number(value=3.5, label=f"Source CFG", interactive=True)
182
  # reconstruction
183
  skip = gr.Number(value=36, precision=0, label="Skip", interactive=True)
184
  tar_cfg_scale = gr.Number(value=15, label=f"Reconstruction CFG", interactive=True)
 
192
  # gr.Markdown(help_text)
193
 
194
  generate_button.click(
195
+ fn=invert_and_reconstruct,
196
  inputs=[input_image,
197
  src_prompt,
198
  tar_prompt,
199
  steps,
200
  src_cfg_scale,
201
  skip,
202
+ tar_cfg_scale
203
+ ],
204
+ outputs=[ddpm_edited_image],
205
+ )
206
+
207
+ edit_button.click(
208
+ fn=edit,
209
+ inputs=[tar_prompt,
210
+ steps,
211
  edit_concept,
212
  sega_edit_guidance,
213
  warm_up,
214
  neg_guidance
215
  ],
216
+ outputs=[sega_edited_image],
217
  )
218
 
219