alaaawad commited on
Commit
aa7dbbe
1 Parent(s): e895c3b

add SD inpainting

Browse files
Files changed (1) hide show
  1. app.py +59 -12
app.py CHANGED
@@ -1,11 +1,14 @@
1
- import os
2
- from diffusers import StableDiffusionInpaintPipeline
3
- from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
4
  import gradio as gr
5
- from PIL import Image
6
  import torch
7
  import matplotlib.pyplot as plt
8
  import cv2
 
 
 
 
 
 
 
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
@@ -15,22 +18,66 @@ clip_seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refine
15
  clip_seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
16
  sd_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
17
 
 
 
 
 
18
 
19
- def process_image(image, prompt_find, prompt_replace):
20
- inputs = clip_seg_processor(text=prompt_find, images=image, padding="max_length", return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # predict
23
  with torch.no_grad():
24
  outputs = clip_seg_model(**inputs)
25
  preds = outputs.logits
26
 
27
- filename_mask = f"mask.png"
28
- plt.imsave(filename_mask, torch.sigmoid(preds))
29
- mask_image = Image.open(filename_mask).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- image = sd_inpainting_pipe(prompt=prompt_replace, image=image, mask_image=mask_image).images[0]
 
 
 
32
 
33
- return mask_image,image
34
 
35
 
36
 
@@ -44,7 +91,7 @@ interface = gr.Interface(fn=process_image,
44
  inputs=[
45
  gr.Image(type="pil"),
46
  gr.Textbox(label="What to identify"),
47
- gr.Textbox(label="What to replace"),
48
  ],
49
  outputs=[
50
  gr.Image(type="pil"),
 
 
 
 
1
  import gradio as gr
 
2
  import torch
3
  import matplotlib.pyplot as plt
4
  import cv2
5
+ import os
6
+
7
+ from diffusers import StableDiffusionInpaintPipeline
8
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
9
+ from PIL import Image
10
+ from torch.cuda.amp import autocast
11
+
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
 
18
  clip_seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
19
  sd_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
20
 
21
+ WIDTH=512
22
+ HEIGHT=512
23
+ DILATE=10
24
+ THRESHOLDS=0.1
25
 
26
+
27
+ def dilate_mask(mask_file):
28
+ image = cv2.imread(mask_file, 0)
29
+ kernel = np.ones((DILATE, DILATE), np.uint8)
30
+ dilated = cv2.dilate(image, kernel, iterations=1)
31
+ im_bin = (dilated > 127) * 255
32
+ cv2.imwrite(mask_file, im_bin)
33
+ return mask_file
34
+
35
+ def process_mask(prompt_find, image, THRESHOLDS=0.1):
36
+ inputs = clip_seg_processor(
37
+ text=prompt_find,
38
+ images=image,
39
+ padding="max_length",
40
+ return_tensors="pt"
41
+ )
42
 
43
  # predict
44
  with torch.no_grad():
45
  outputs = clip_seg_model(**inputs)
46
  preds = outputs.logits
47
 
48
+ out_img = torch.sigmoid(preds)
49
+ out_img = (out_img - out_img.min()) / out_img.max()
50
+ if isinstance(THRESHOLDS, list):
51
+ if len(THRESHOLDS) >= 2:
52
+ out_img = torch.where(out_img >= THRESHOLDS[1], 1., out_img)
53
+ out_img = torch.where(out_img <= THRESHOLDS[0], 0., out_img)
54
+ else:
55
+ out_img = torch.where(out_img >= THRESHOLDS[0], 1., 0.)
56
+ else:
57
+ out_img = torch.where(out_img >= THRESHOLDS, 1., 0.)
58
+
59
+ mask_file="mask.png"
60
+ plt.imsave(mask_file, out_img)
61
+ dilated_mask = dilate_mask(mask_file)
62
+
63
+ mask_image = Image.open(dilated_mask)
64
+
65
+ return mask_image
66
+
67
+ def process_inpaint(prompt_replace, image, mask_image):
68
+ image = sd_inpainting_pipe(
69
+ prompt=prompt_replace,
70
+ image=image,
71
+ mask_image=mask_image
72
+ ).images[0]
73
+ return image
74
 
75
+ def process_image(image, prompt_find, prompt_replace):
76
+ orig_image = image.resize((WIDTH, HEIGHT))
77
+ mask_image = process_mask(prompt_find, orig_image).resize((WIDTH, HEIGHT))
78
+ new_image = process_inpaint(prompt_replace, orig_image, mask_image)
79
 
80
+ return new_image, mask_image
81
 
82
 
83
 
 
91
  inputs=[
92
  gr.Image(type="pil"),
93
  gr.Textbox(label="What to identify"),
94
+ gr.Textbox(label="What to replace it with"),
95
  ],
96
  outputs=[
97
  gr.Image(type="pil"),