Linoy Tsaban commited on
Commit
3489b04
1 Parent(s): bf5aed6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -15
app.py CHANGED
@@ -52,41 +52,119 @@ sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "schedule
52
  sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
53
 
54
 
55
- def edit(input_image, input_image_prompt='', target_prompt='', edit_prompt='',
56
- negative_guidance = False, edit_warmup_steps=5,
57
- edit_guidance_scale=8, guidance_scale=15, skip=36, num_diffusion_steps=100,
58
- ):
 
 
 
 
 
 
 
59
  offsets=(0,0,0,0)
60
  x0 = load_512(input_image, *offsets, device)
61
 
62
 
63
  # invert
64
- wt, zs, wts = invert(x0 =x0 , prompt_src=input_image_prompt, num_diffusion_steps=num_diffusion_steps)
65
  latnets = wts[skip].expand(1, -1, -1, -1)
66
 
67
  eta = 1
68
  #pure DDPM output
69
- pure_ddpm_out = sample(wt, zs, wts, prompt_tar=target_prompt,
70
- cfg_scale_tar=guidance_scale, skip=skip,
71
  eta = eta)
72
 
73
  editing_args = dict(
74
- editing_prompt = [edit_prompt],
75
- reverse_editing_direction = [negative_guidance],
76
- edit_warmup_steps=[edit_warmup_steps],
77
- edit_guidance_scale=[edit_guidance_scale],
78
  edit_threshold=[.93],
79
  edit_momentum_scale=0.5,
80
  edit_mom_beta=0.6
81
  )
82
- sega_out = sem_pipe(prompt=target_prompt,eta=eta, latents=latnets,
83
  num_images_per_prompt=1,
84
- num_inference_steps=num_diffusion_steps,
85
  use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
86
  return pure_ddpm_out,sega_out.images[0]
87
 
88
 
89
- # See the gradio docs for the types of inputs and outputs available
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  inputs = [
91
  gr.Image(label="input image", shape=(512, 512)),
92
  gr.Textbox(label="input prompt"),
@@ -109,4 +187,6 @@ demo = gr.Interface(
109
  inputs=inputs,
110
  outputs=outputs,
111
  )
112
- demo.launch() # debug=True allows you to see errors and output in Colab
 
 
 
52
  sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
53
 
54
 
55
+ def edit(input_image,
56
+ src_prompt,
57
+ tar_prompt,
58
+ steps,
59
+ src_cfg_scale,
60
+ skip,
61
+ tar_cfg_scale,
62
+ edit_concept,
63
+ sega_edit_guidance,
64
+ warm_up,
65
+ neg_guidance):
66
  offsets=(0,0,0,0)
67
  x0 = load_512(input_image, *offsets, device)
68
 
69
 
70
  # invert
71
+ wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
72
  latnets = wts[skip].expand(1, -1, -1, -1)
73
 
74
  eta = 1
75
  #pure DDPM output
76
+ pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
77
+ cfg_scale_tar=tar_cfg_scale, skip=skip,
78
  eta = eta)
79
 
80
  editing_args = dict(
81
+ editing_prompt = [edit_concept],
82
+ reverse_editing_direction = [neg_guidance],
83
+ edit_warmup_steps=[warm_up],
84
+ edit_guidance_scale=[sega_edit_guidance],
85
  edit_threshold=[.93],
86
  edit_momentum_scale=0.5,
87
  edit_mom_beta=0.6
88
  )
89
+ sega_out = sem_pipe(prompt=tar_prompt,eta=eta, latents=latnets,
90
  num_images_per_prompt=1,
91
+ num_inference_steps=steps,
92
  use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
93
  return pure_ddpm_out,sega_out.images[0]
94
 
95
 
96
+ ####################################3
97
+
98
+ with gr.Blocks() as demo:
99
+ gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">
100
+ Edit Friendly DDPM X Semantic Guidance: Editing Real Images
101
+ </h1>
102
+ <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
103
+ <br/>
104
+ <a href="https://huggingface.co/spaces/LinoyTsaban/ddpm_sega?duplicate=true">
105
+ <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
106
+ <p/>""")
107
+ with gr.Row():
108
+ with gr.Column(scale=1, min_width=100):
109
+ generate_button = gr.Button("Generate")
110
+ # with gr.Column(scale=1, min_width=100):
111
+ # reset_button = gr.Button("Reset")
112
+ # with gr.Column(scale=3):
113
+ # instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
114
+
115
+ with gr.Row():
116
+ input_image = gr.Image(label="Input Image", type="pil", interactive=True)
117
+ ddpm_edited_image = gr.Image(label=f"Reconstructed Image", type="pil", interactive=False)
118
+ sega_edited_image = gr.Image(label=f"Edited Image", type="pil", interactive=False)
119
+ input_image.style(height=512, width=512)
120
+ ddpm_edited_image.style(height=512, width=512)
121
+ sega_edited_image.style(height=512, width=512)
122
+
123
+ with gr.Row():
124
+ src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True)
125
+ #edit
126
+ tar_prompt = gr.Textbox(lines=1, label="Target Prompt", interactive=True)
127
+
128
+ with gr.Row():
129
+ #inversion
130
+ steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
131
+ src_cfg_scale = gr.Number(value=3.5, label=f"Source CFG", interactive=True)
132
+ # reconstruction
133
+ skip = gr.Number(value=100, precision=0, label="Skip", interactive=True)
134
+ tar_cfg_scale = gr.Number(value=15, label=f"Reconstruction CFG", interactive=True)
135
+ # edit
136
+ edit_concept = gr.Textbox(lines=1, label="Edit Concept", interactive=True)
137
+ sega_edit_guidance = gr.Number(value=5, label=f"SEGA CFG", interactive=True)
138
+ warm_up = gr.Number(value=5, label=f"Warm-up Steps", interactive=True)
139
+ neg_guidance = gr.Checkbox(label="SEGA negative_guidance")
140
+
141
+
142
+ gr.Markdown(help_text)
143
+
144
+ generate_button.click(
145
+ fn=edit,
146
+ inputs=[input_image,
147
+ src_prompt,
148
+ tar_prompt,
149
+ steps,
150
+ src_cfg_scale,
151
+ skip,
152
+ tar_cfg_scale,
153
+ edit_concept,
154
+ sega_edit_guidance,
155
+ warm_up,
156
+ neg_guidance
157
+ ],
158
+ outputs=[input_image, ddpm_edited_image, sega_edited_image],
159
+ )
160
+
161
+
162
+ demo.queue(concurrency_count=1)
163
+ demo.launch(share=False)
164
+ ######################################################
165
+
166
+
167
+
168
  inputs = [
169
  gr.Image(label="input image", shape=(512, 512)),
170
  gr.Textbox(label="input prompt"),
 
187
  inputs=inputs,
188
  outputs=outputs,
189
  )
190
+ demo.launch() # debug=True allows you to see errors and output in Colab
191
+
192
+