Spaces:
Runtime error
Runtime error
test
Browse files
app.py
CHANGED
|
@@ -160,6 +160,16 @@ def inpaint_image(image, prompt, subject, editor_value):
|
|
| 160 |
target_text_prompt=prompt
|
| 161 |
prompt_final=f'A two side-by-side image of same {subject_name}. LEFT: a photo of the {subject_name}; RIGHT: a photo of the {subject_name} {target_text_prompt}.'
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
# Inpaint
|
| 164 |
result = pipe(
|
| 165 |
prompt=prompt_final,
|
|
@@ -175,7 +185,7 @@ def inpaint_image(image, prompt, subject, editor_value):
|
|
| 175 |
true_guidance_scale=1.0,
|
| 176 |
attn_scale_mask=full_attn_scale_mask,
|
| 177 |
).images[0]
|
| 178 |
-
return result,
|
| 179 |
|
| 180 |
# Create Gradio interface with structured layout
|
| 181 |
with gr.Blocks() as iface:
|
|
@@ -192,12 +202,12 @@ with gr.Blocks() as iface:
|
|
| 192 |
with gr.Column():
|
| 193 |
editor_value = gr.ImageEditor(type="pil", label="Image with Mask", sources="upload", visible=False)
|
| 194 |
inpainted_image = gr.Image(type="pil", label="Inpainted Image")
|
| 195 |
-
|
| 196 |
with gr.Row():
|
| 197 |
|
| 198 |
inpaint_button = gr.Button("Inpaint")
|
| 199 |
|
| 200 |
-
inpaint_button.click(fn=inpaint_image, inputs=[input_image, prompt, subject, editor_value], outputs=[inpainted_image,
|
| 201 |
|
| 202 |
# Launch the app
|
| 203 |
iface.launch()
|
|
|
|
| 160 |
target_text_prompt=prompt
|
| 161 |
prompt_final=f'A two side-by-side image of same {subject_name}. LEFT: a photo of the {subject_name}; RIGHT: a photo of the {subject_name} {target_text_prompt}.'
|
| 162 |
|
| 163 |
+
# Convert attention mask to PIL image format
|
| 164 |
+
# Take first head's mask after prompt tokens (shape is now H*W x H*W)
|
| 165 |
+
attn_vis = full_attn_scale_mask[0, 0]
|
| 166 |
+
attn_vis[attn_vis <= 1.0] = 0
|
| 167 |
+
attn_vis[attn_vis > 1.0] = 255
|
| 168 |
+
attn_vis = attn_vis.cpu().float().numpy().astype(np.uint8)
|
| 169 |
+
# # Convert to PIL Image
|
| 170 |
+
attn_vis_img = Image.fromarray(attn_vis)
|
| 171 |
+
attn_vis_img.save('attention_mask_vis.png')
|
| 172 |
+
|
| 173 |
# Inpaint
|
| 174 |
result = pipe(
|
| 175 |
prompt=prompt_final,
|
|
|
|
| 185 |
true_guidance_scale=1.0,
|
| 186 |
attn_scale_mask=full_attn_scale_mask,
|
| 187 |
).images[0]
|
| 188 |
+
return result, attn_vis_img
|
| 189 |
|
| 190 |
# Create Gradio interface with structured layout
|
| 191 |
with gr.Blocks() as iface:
|
|
|
|
| 202 |
with gr.Column():
|
| 203 |
editor_value = gr.ImageEditor(type="pil", label="Image with Mask", sources="upload", visible=False)
|
| 204 |
inpainted_image = gr.Image(type="pil", label="Inpainted Image")
|
| 205 |
+
attn_vis_img = gr.Image(type="pil", label="Attn Vis Image")
|
| 206 |
with gr.Row():
|
| 207 |
|
| 208 |
inpaint_button = gr.Button("Inpaint")
|
| 209 |
|
| 210 |
+
inpaint_button.click(fn=inpaint_image, inputs=[input_image, prompt, subject, editor_value], outputs=[inpainted_image, attn_vis_img])
|
| 211 |
|
| 212 |
# Launch the app
|
| 213 |
iface.launch()
|