Spaces:
Runtime error
Runtime error
add SD inpainting
Browse files
app.py
CHANGED
@@ -1,11 +1,14 @@
|
|
1 |
-
import os
|
2 |
-
from diffusers import StableDiffusionInpaintPipeline
|
3 |
-
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
4 |
import gradio as gr
|
5 |
-
from PIL import Image
|
6 |
import torch
|
7 |
import matplotlib.pyplot as plt
|
8 |
import cv2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
|
@@ -15,22 +18,66 @@ clip_seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refine
|
|
15 |
clip_seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
16 |
sd_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
|
17 |
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# predict
|
23 |
with torch.no_grad():
|
24 |
outputs = clip_seg_model(**inputs)
|
25 |
preds = outputs.logits
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
|
|
|
|
|
|
32 |
|
33 |
-
return mask_image
|
34 |
|
35 |
|
36 |
|
@@ -44,7 +91,7 @@ interface = gr.Interface(fn=process_image,
|
|
44 |
inputs=[
|
45 |
gr.Image(type="pil"),
|
46 |
gr.Textbox(label="What to identify"),
|
47 |
-
gr.Textbox(label="What to replace"),
|
48 |
],
|
49 |
outputs=[
|
50 |
gr.Image(type="pil"),
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import torch
|
3 |
import matplotlib.pyplot as plt
|
4 |
import cv2
|
5 |
+
import os
|
6 |
+
|
7 |
+
from diffusers import StableDiffusionInpaintPipeline
|
8 |
+
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
9 |
+
from PIL import Image
|
10 |
+
from torch.cuda.amp import autocast
|
11 |
+
|
12 |
|
13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
|
|
|
18 |
clip_seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
19 |
sd_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
|
20 |
|
21 |
+
# Pixel dimensions that process_image resizes the input image and mask to.
WIDTH = 512
HEIGHT = 512

# Side length of the square kernel that dilate_mask uses to grow the mask.
DILATE = 10

# Module-level threshold value. NOTE(review): process_mask declares its own
# parameter with this name and the same default, so the visible code never
# reads this global directly -- confirm whether it is still needed.
THRESHOLDS = 0.1
|
25 |
|
26 |
+
|
27 |
+
def dilate_mask(mask_file):
    """Dilate and binarize the mask image stored at ``mask_file``, in place.

    The file is read as a single-channel grayscale image, dilated once with a
    ``DILATE`` x ``DILATE`` kernel of ones, thresholded at 127 into pure
    0/255 values, and written back to the same path.

    Args:
        mask_file: Path of the mask image to read and overwrite.

    Returns:
        The same ``mask_file`` path that was passed in.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``mask_file``.
    """
    # NumPy is not among the file's visible top-level imports, so import it
    # locally; without this the ``np`` references below raise NameError.
    import numpy as np

    image = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask image: {mask_file}")

    kernel = np.ones((DILATE, DILATE), np.uint8)
    dilated = cv2.dilate(image, kernel, iterations=1)

    # Build the 0/255 mask explicitly as uint8 so the array handed to
    # cv2.imwrite is an 8-bit image rather than a platform-sized integer array.
    im_bin = np.where(dilated > 127, 255, 0).astype(np.uint8)
    cv2.imwrite(mask_file, im_bin)
    return mask_file
|
34 |
+
|
35 |
+
def process_mask(prompt_find, image, THRESHOLDS=0.1):
    """Build a dilated black/white mask of the part of ``image`` matching a prompt.

    The prompt and image are run through the CLIPSeg processor and model, the
    logits are squashed with a sigmoid, rescaled to the 0..1 range and
    thresholded. The result is written to ``mask.png`` in the current working
    directory, dilated via ``dilate_mask`` and reopened as a PIL image.

    Args:
        prompt_find: Text describing what to segment.
        image: Image passed to the CLIPSeg processor as ``images``.
        THRESHOLDS: Either a single number (values >= it become 1, the rest 0)
            or a list/tuple. A one-element sequence behaves like a single
            number. With two or more elements, values >= ``THRESHOLDS[1]``
            become 1, values <= ``THRESHOLDS[0]`` become 0, and values in
            between are kept as they are.

    Returns:
        The PIL image opened from the dilated ``mask.png`` file.

    Raises:
        ValueError: If ``THRESHOLDS`` is an empty list or tuple.
    """
    inputs = clip_seg_processor(
        text=prompt_find,
        images=image,
        padding="max_length",
        return_tensors="pt",
    )

    # predict
    with torch.no_grad():
        outputs = clip_seg_model(**inputs)
    preds = outputs.logits

    out_img = torch.sigmoid(preds)

    # Min-max normalisation: divide by the value range (max - min), not by the
    # raw maximum, so the largest value really maps to 1. A completely flat
    # prediction has no range and is mapped to all zeros.
    lowest = out_img.min()
    value_range = out_img.max() - lowest
    if value_range > 0:
        out_img = (out_img - lowest) / value_range
    else:
        out_img = torch.zeros_like(out_img)

    if isinstance(THRESHOLDS, (list, tuple)):
        if not THRESHOLDS:
            raise ValueError("THRESHOLDS must contain at least one value")
        if len(THRESHOLDS) >= 2:
            # Order matters: saturate the confident region first, then clear
            # the background; everything in between keeps its soft value.
            out_img = torch.where(
                out_img >= THRESHOLDS[1], torch.ones_like(out_img), out_img
            )
            out_img = torch.where(
                out_img <= THRESHOLDS[0], torch.zeros_like(out_img), out_img
            )
        else:
            out_img = (out_img >= THRESHOLDS[0]).to(out_img.dtype)
    else:
        out_img = (out_img >= THRESHOLDS).to(out_img.dtype)

    mask_file = "mask.png"
    # Drop singleton dimensions and hand matplotlib a plain NumPy array.
    # NOTE(review): assumes a single prompt/image pair so the squeezed result
    # is 2-D -- confirm against the CLIPSeg output shape in use.
    mask_array = out_img.squeeze().cpu().numpy()
    # A fixed grayscale mapping (0 -> black, 1 -> white) keeps the saved pixel
    # values independent of the default colormap and of autoscaling, which
    # would otherwise collapse a constant mask to a single colour.
    plt.imsave(mask_file, mask_array, cmap="gray", vmin=0.0, vmax=1.0)
    dilated_mask = dilate_mask(mask_file)

    mask_image = Image.open(dilated_mask)

    return mask_image
|
66 |
+
|
67 |
+
def process_inpaint(prompt_replace, image, mask_image):
    """Run the Stable Diffusion inpainting pipeline and return its first image.

    Args:
        prompt_replace: Text prompt forwarded to the pipeline as ``prompt``.
        image: Source image forwarded as ``image``.
        mask_image: Mask forwarded as ``mask_image``.

    Returns:
        The first entry of the pipeline result's ``images`` attribute.
    """
    pipeline_output = sd_inpainting_pipe(
        prompt=prompt_replace,
        image=image,
        mask_image=mask_image,
    )
    generated_images = pipeline_output.images
    return generated_images[0]
|
74 |
|
75 |
+
def process_image(image, prompt_find, prompt_replace):
    """Find a region in ``image`` by prompt and inpaint it from another prompt.

    Args:
        image: Input image; it is resized to ``WIDTH`` x ``HEIGHT`` first.
        prompt_find: Text passed to ``process_mask`` to locate the region.
        prompt_replace: Text passed to ``process_inpaint`` to fill the region.

    Returns:
        A ``(new_image, mask_image)`` tuple: the inpainted image and the mask
        that was used, the latter resized to ``WIDTH`` x ``HEIGHT``.
    """
    target_size = (WIDTH, HEIGHT)
    resized_input = image.resize(target_size)

    raw_mask = process_mask(prompt_find, resized_input)
    resized_mask = raw_mask.resize(target_size)

    inpainted = process_inpaint(prompt_replace, resized_input, resized_mask)
    return inpainted, resized_mask
|
81 |
|
82 |
|
83 |
|
|
|
91 |
inputs=[
|
92 |
gr.Image(type="pil"),
|
93 |
gr.Textbox(label="What to identify"),
|
94 |
+
gr.Textbox(label="What to replace it with"),
|
95 |
],
|
96 |
outputs=[
|
97 |
gr.Image(type="pil"),
|