Linoy Tsaban committed on
Commit 4697625
1 Parent(s): acc80f0

Create app.py

Files changed (1)
  1. app.py +259 -0
app.py ADDED
@@ -0,0 +1,259 @@
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from torch import autocast, inference_mode
from utils import *
from inversion_utils import *

def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):

    # inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
    # based on the code in https://github.com/inbarhub/DDPM_inversion

    # returns wt, zs, wts:
    # wt  - inverted latent
    # wts - intermediate inverted latents
    # zs  - noise maps

    sd_pipe.scheduler.set_timesteps(num_diffusion_steps)

    # VAE-encode the image
    with autocast("cuda"), inference_mode():
        w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()

    # find zs and wts - forward process
    wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src,
                                            cfg_scale=cfg_scale_src, prog_bar=True,
                                            num_inference_steps=num_diffusion_steps)
    return wt, zs, wts


def sample(wt, zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):

    # reverse process (via zs and wT)
    w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar],
                                      cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])

    # VAE-decode the image
    with autocast("cuda"), inference_mode():
        x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
    if x0_dec.dim() < 4:
        x0_dec = x0_dec[None, :, :, :]
    img = image_grid(x0_dec)
    return img

# load pipelines
# sd_model_id = "runwayml/stable-diffusion-v1-5"
sd_model_id = "CompVis/stable-diffusion-v1-4"
# sd_model_id = "stabilityai/stable-diffusion-2-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")
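# Note: the pipeline's default scheduler is swapped for DDIM, which the DDPM
# inversion code above is written against. (Presumably, loading the pipeline
# with torch_dtype=torch.float16 would reduce GPU memory; default precision
# is kept here.)
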
def get_example():
    case = [
        [
            'examples/source_a_man_wearing_a_brown_hoodie_in_a_crowded_street.jpeg',
            'a man wearing a brown hoodie in a crowded street',
            'a robot wearing a brown hoodie in a crowded street',
            100,
            36,
            15,
            '+painting',
            10,
            1,
            'examples/ddpm_a_robot_wearing_a_brown_hoodie_in_a_crowded_street.png',
            'examples/ddpm_sega_painting_of_a_robot_wearing_a_brown_hoodie_in_a_crowded_street.png'
        ],
        [
            'examples/source_wall_with_framed_photos.jpeg',
            '',
            '',
            100,
            36,
            15,
            '+pink drawings of muffins',
            10,
            1,
            'examples/ddpm_wall_with_framed_photos.png',
            'examples/ddpm_sega_plus_pink_drawings_of_muffins.png'
        ],
        [
            'examples/source_an_empty_room_with_concrete_walls.jpg',
            'an empty room with concrete walls',
            'glass walls',
            100,
            36,
            17,
            '+giant elephant',
            10,
            1,
            'examples/ddpm_glass_walls.png',
            'examples/ddpm_sega_glass_walls_gian_elephant.png'
        ]]
    return case

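# module-level cache holding the results of the most recent inversion; shared
# between the "Load & Invert" and "Sample & Edit" callbacks below and cleared
# whenever the input image changes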
inversion_map = dict()

def load_and_invert(input_image,
                    src_prompt="",
                    steps=100,
                    src_cfg_scale=3.5,
                    left=0,
                    right=0,
                    top=0,
                    bottom=0
                    ):
    # crop/shift offsets default to (0,0,0,0)
    x0 = load_512(input_image, left, right, top, bottom, device)

    # invert and cache the results; the skip-dependent starting latent is
    # selected later in edit(), where `skip` is known
    wt, zs, wts = invert(x0=x0, prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
    inversion_map['wt'] = wt
    inversion_map['zs'] = zs
    inversion_map['wts'] = wts

    return

def edit(tar_prompt="",
         steps=100,
         skip=36,
         tar_cfg_scale=15,
         ):
    # retrieve the cached inversion results
    wt = inversion_map['wt']
    zs = inversion_map['zs']
    wts = inversion_map['wts']

    outputs = []
    num_generations = 1
    for i in range(num_generations):
        out = sample(wt, zs, wts, prompt_tar=tar_prompt,
                     cfg_scale_tar=tar_cfg_scale, skip=skip)
        outputs.append(out)

    # a single gr.Image is wired as the output component
    return outputs[0]

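# `skip` selects which intermediate inverted latent sampling starts from
# (wts[skip], paired with the noise maps zs[skip:]); it trades off fidelity
# to the input image against the strength of the edit.
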
def reset():
    inversion_map.clear()


########
# demo #
########

intro = """
<h1 style="font-weight: 700; text-align: center; margin-bottom: 7px;">
   Edit Friendly DDPM Inversion
</h1>
<p style="font-size: 0.9rem; text-align: center; margin: 0rem; line-height: 1.2em; margin-top:1em">
<a href="https://arxiv.org/abs/2304.06140" style="text-decoration: underline;" target="_blank">An Edit Friendly DDPM Noise Space:
Inversion and Manipulations</a>
</p>
<p style="font-size: 0.9rem; margin: 0rem; line-height: 1.2em; margin-top:1em">
For faster inference, without waiting in a queue, you may duplicate the Space and upgrade to a GPU in the settings.
<a href="https://huggingface.co/spaces/LinoyTsaban/ddpm_sega?duplicate=true">
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
</p>"""
with gr.Blocks() as demo:
    gr.HTML(intro)
    with gr.Row():
        src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True, placeholder="optional: describe the original image")
        tar_prompt = gr.Textbox(lines=1, label="Target Prompt", interactive=True, placeholder="optional: describe the target image to edit with DDPM")

    with gr.Row():
        input_image = gr.Image(label="Input Image", interactive=True)
        input_image.style(height=512, width=512)
        output_image = gr.Image(label="Edited Image", interactive=False)
        output_image.style(height=512, width=512)

    with gr.Row():
        with gr.Column(scale=1, min_width=100):
            invert_button = gr.Button("Load & Invert")
        with gr.Column(scale=1, min_width=100):
            edit_button = gr.Button("Sample & Edit")

    with gr.Accordion("Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                # inversion
                steps = gr.Number(value=100, precision=0, label="Num Diffusion Steps", interactive=True)
                src_cfg_scale = gr.Slider(minimum=1, maximum=15, value=3.5, label="Source Guidance Scale", interactive=True)

                # reconstruction
                skip = gr.Slider(minimum=0, maximum=40, value=36, step=1, label="Skip Steps", interactive=True)
                tar_cfg_scale = gr.Slider(minimum=7, maximum=18, value=15, label="Target Guidance Scale", interactive=True)

            # shift
            with gr.Column():
                left = gr.Number(value=0, precision=0, label="Left Shift", interactive=True)
                right = gr.Number(value=0, precision=0, label="Right Shift", interactive=True)
                top = gr.Number(value=0, precision=0, label="Top Shift", interactive=True)
                bottom = gr.Number(value=0, precision=0, label="Bottom Shift", interactive=True)

    # gr.Markdown(help_text)

    invert_button.click(
        fn=load_and_invert,
        inputs=[input_image,
                src_prompt,
                steps,
                src_cfg_scale,
                left,
                right,
                top,
                bottom
                ],
        outputs=[],
    )

    edit_button.click(
        fn=edit,
        inputs=[tar_prompt,
                steps,
                skip,
                tar_cfg_scale,
                ],
        outputs=[output_image],
    )

    # clear cached inversion results whenever a new image is loaded
    input_image.change(
        fn=reset
    )

    # gr.Examples(
    #     label='Examples',
    #     examples=get_example(),
    #     inputs=[input_image, src_prompt, tar_prompt, steps,
    #             # src_cfg_scale,
    #             skip,
    #             tar_cfg_scale,
    #             edit_concept,
    #             sega_edit_guidance,
    #             warm_up,
    #             # neg_guidance,
    #             ddpm_edited_image, sega_edited_image
    #             ],
    #     outputs=[ddpm_edited_image, sega_edited_image],
    #     # fn=edit,
    #     # cache_examples=True
    # )

demo.queue()
demo.launch(share=False)