Linoy Tsaban committed on
Commit
7078734
·
1 Parent(s): bf289f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -72
app.py CHANGED
@@ -50,15 +50,19 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
51
  sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
52
  sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
53
- latents, wts, zs = None, None, None
54
 
55
- def invert_and_reconstruct(input_image,
 
56
  src_prompt,
57
  tar_prompt,
58
  steps,
59
  # src_cfg_scale,
60
  skip,
61
- tar_cfg_scale):
 
 
 
 
62
  offsets=(0,0,0,0)
63
  x0 = load_512(input_image, *offsets, device)
64
 
@@ -73,16 +77,7 @@ def invert_and_reconstruct(input_image,
73
  pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
74
  cfg_scale_tar=tar_cfg_scale, skip=skip,
75
  eta = eta)
76
- return pure_ddpm_out
77
-
78
- def edit( input_image,
79
- tar_prompt,
80
- steps,
81
- edit_concept,
82
- sega_edit_guidance,
83
- warm_up,
84
- neg_guidance):
85
-
86
  editing_args = dict(
87
  editing_prompt = [edit_concept],
88
  reverse_editing_direction = [neg_guidance],
@@ -96,48 +91,7 @@ def edit( input_image,
96
  num_images_per_prompt=1,
97
  num_inference_steps=steps,
98
  use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
99
- return sega_out.images[0]
100
-
101
- # def edit(input_image,
102
- # src_prompt,
103
- # tar_prompt,
104
- # steps,
105
- # # src_cfg_scale,
106
- # skip,
107
- # tar_cfg_scale,
108
- # edit_concept,
109
- # sega_edit_guidance,
110
- # warm_up,
111
- # neg_guidance):
112
- # offsets=(0,0,0,0)
113
- # x0 = load_512(input_image, *offsets, device)
114
-
115
-
116
- # # invert
117
- # # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
118
- # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps)
119
- # latnets = wts[skip].expand(1, -1, -1, -1)
120
-
121
- # eta = 1
122
- # #pure DDPM output
123
- # pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
124
- # cfg_scale_tar=tar_cfg_scale, skip=skip,
125
- # eta = eta)
126
-
127
- # editing_args = dict(
128
- # editing_prompt = [edit_concept],
129
- # reverse_editing_direction = [neg_guidance],
130
- # edit_warmup_steps=[warm_up],
131
- # edit_guidance_scale=[sega_edit_guidance],
132
- # edit_threshold=[.93],
133
- # edit_momentum_scale=0.5,
134
- # edit_mom_beta=0.6
135
- # )
136
- # sega_out = sem_pipe(prompt=tar_prompt,eta=eta, latents=latnets,
137
- # num_images_per_prompt=1,
138
- # num_inference_steps=steps,
139
- # use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
140
- # return pure_ddpm_out,sega_out.images[0]
141
 
142
 
143
  ####################################
@@ -163,9 +117,7 @@ with gr.Blocks() as demo:
163
 
164
  with gr.Row():
165
  with gr.Column(scale=1, min_width=100):
166
- generate_button = gr.Button("Invert")
167
- with gr.Column(scale=1, min_width=100):
168
- edit_button = gr.Button("Edit")
169
  # with gr.Column(scale=1, min_width=100):
170
  # reset_button = gr.Button("Reset")
171
  # with gr.Column(scale=3):
@@ -193,7 +145,7 @@ with gr.Blocks() as demo:
193
  # gr.Markdown(help_text)
194
 
195
  generate_button.click(
196
- fn=invert_and_reconstruct,
197
  inputs=[input_image,
198
  src_prompt,
199
  tar_prompt,
@@ -205,19 +157,6 @@ with gr.Blocks() as demo:
205
  outputs=[ddpm_edited_image],
206
  )
207
 
208
- edit_button.click(
209
- fn=edit,
210
- inputs=[
211
- input_image,
212
- tar_prompt,
213
- steps,
214
- edit_concept,
215
- sega_edit_guidance,
216
- warm_up,
217
- neg_guidance
218
- ],
219
- outputs=[sega_edited_image],
220
- )
221
 
222
 
223
  demo.queue(concurrency_count=1)
 
50
  sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
51
  sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
52
  sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
 
53
 
54
+
55
+ def edit(input_image,
56
  src_prompt,
57
  tar_prompt,
58
  steps,
59
  # src_cfg_scale,
60
  skip,
61
+ tar_cfg_scale,
62
+ edit_concept,
63
+ sega_edit_guidance,
64
+ warm_up,
65
+ neg_guidance):
66
  offsets=(0,0,0,0)
67
  x0 = load_512(input_image, *offsets, device)
68
 
 
77
  pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
78
  cfg_scale_tar=tar_cfg_scale, skip=skip,
79
  eta = eta)
80
+
 
 
 
 
 
 
 
 
 
81
  editing_args = dict(
82
  editing_prompt = [edit_concept],
83
  reverse_editing_direction = [neg_guidance],
 
91
  num_images_per_prompt=1,
92
  num_inference_steps=steps,
93
  use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
94
+ return pure_ddpm_out,sega_out.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
  ####################################
 
117
 
118
  with gr.Row():
119
  with gr.Column(scale=1, min_width=100):
120
+ generate_button = gr.Button("Generate")
 
 
121
  # with gr.Column(scale=1, min_width=100):
122
  # reset_button = gr.Button("Reset")
123
  # with gr.Column(scale=3):
 
145
  # gr.Markdown(help_text)
146
 
147
  generate_button.click(
148
+ fn=edit,
149
  inputs=[input_image,
150
  src_prompt,
151
  tar_prompt,
 
157
  outputs=[ddpm_edited_image],
158
  )
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
 
162
  demo.queue(concurrency_count=1)