Spaces:
Build error
Build error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from diffusers import StableDiffusionInpaintPipeline
|
8 |
+
|
9 |
+
auth_token = os.getenv("auth_token")
|
10 |
+
|
11 |
+
|
12 |
+
def preview(image, state):
|
13 |
+
h, w = image.shape[:2]
|
14 |
+
scale_percent = 512 / max([w, h])
|
15 |
+
|
16 |
+
width = int(w * scale_percent)
|
17 |
+
height = int(h * scale_percent)
|
18 |
+
dim = (width, height)
|
19 |
+
resized = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
|
20 |
+
yoff = round((512-height)/2)
|
21 |
+
xoff = round((512-width)/2)
|
22 |
+
|
23 |
+
final_image = np.zeros((512, 512, 3), dtype=np.uint8)
|
24 |
+
final_image.fill(120)
|
25 |
+
final_image[yoff:yoff+height, xoff:xoff+width, :] = resized
|
26 |
+
|
27 |
+
mask_image = np.zeros((512, 512, 3), dtype=np.uint8)
|
28 |
+
mask_image.fill(255)
|
29 |
+
mask_image[yoff:yoff+height, xoff:xoff+width, :] = 0
|
30 |
+
state.clear()
|
31 |
+
state.append(mask_image)
|
32 |
+
state.append([yoff, xoff, height, width])
|
33 |
+
state.append(image)
|
34 |
+
|
35 |
+
return final_image, state
|
36 |
+
|
37 |
+
|
38 |
+
def sd_inpaint(image, prompt, state):
|
39 |
+
mask = state[0]
|
40 |
+
yoff, xoff, height, width = state[1]
|
41 |
+
orig_image = state[2]
|
42 |
+
|
43 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
44 |
+
|
45 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
46 |
+
"runwayml/stable-diffusion-inpainting",
|
47 |
+
revision="fp16",
|
48 |
+
torch_dtype=torch.float16,
|
49 |
+
use_auth_token=auth_token
|
50 |
+
).to(device)
|
51 |
+
|
52 |
+
output = pipe(prompt=prompt, image=Image.fromarray(image), mask_image=Image.fromarray(mask)).images[0]
|
53 |
+
result = np.array(output)
|
54 |
+
result[yoff:yoff+height, xoff:xoff+width, :] = orig_image
|
55 |
+
result = Image.fromarray(result)
|
56 |
+
|
57 |
+
return result
|
58 |
+
|
59 |
+
|
60 |
+
with gr.Blocks(title='Dreambooth Image Editing and Stable Diffusion Inpainting') as demo:
|
61 |
+
state = gr.State([])
|
62 |
+
gr.Markdown("# Dreambooth Image Editing and Stable Diffusion Inpainting")
|
63 |
+
gr.Markdown("It's difficult to get a good image to use for dreambooth, I do not have many photograhps of myself alone and it's very slow to edit the images (crop the selection, scale it to 512x512 and solve the problem of the background somehow)")
|
64 |
+
gr.Markdown("This app uses a combination of image selection, automatic scaling, and stable diffusion inpainting to speed that process. Follow the next instructions:")
|
65 |
+
gr.Markdown("""- Upload an image
|
66 |
+
- Use the select tool to select the area you want to use for dreambooth
|
67 |
+
- The image will be resized to 512x512 and fill the rest of with a gray background
|
68 |
+
- Then click the Inpaint button to use stable diffusion to inpaint the background
|
69 |
+
- Save the image and use it for dreambooth
|
70 |
+
""")
|
71 |
+
with gr.Row():
|
72 |
+
with gr.Column():
|
73 |
+
img_ctr = gr.Image(tool='select')
|
74 |
+
with gr.Column():
|
75 |
+
output = gr.Image()
|
76 |
+
with gr.Row():
|
77 |
+
greet_btn = gr.Button("Selection")
|
78 |
+
with gr.Row():
|
79 |
+
sd_prompt = gr.Textbox(lines=2, label="Stable diffusion prompt")
|
80 |
+
with gr.Row():
|
81 |
+
final_image = gr.Image()
|
82 |
+
with gr.Row():
|
83 |
+
stab_btn = gr.Button("Inpaint")
|
84 |
+
|
85 |
+
greet_btn.click(fn=preview, inputs=[img_ctr, state], outputs=[output, state])
|
86 |
+
stab_btn.click(fn=sd_inpaint, inputs=[output, sd_prompt, state], outputs=final_image)
|
87 |
+
|
88 |
+
|
89 |
+
demo.launch()
|