alaaawad commited on
Commit
e895c3b
1 Parent(s): 595105c

add SD inpainting

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +13 -3
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  .DS_Store
 
2
  *.pth
3
  # Byte-compiled / optimized / DLL files
4
  __pycache__/
 
1
  .DS_Store
2
+ mask.png
3
  *.pth
4
  # Byte-compiled / optimized / DLL files
5
  __pycache__/
app.py CHANGED
@@ -7,11 +7,17 @@ import torch
7
  import matplotlib.pyplot as plt
8
  import cv2
9
 
 
 
 
 
10
  clip_seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
11
  clip_seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
 
 
12
 
13
- def process_image(image, prompt):
14
- inputs = clip_seg_processor(text=prompt, images=image, padding="max_length", return_tensors="pt")
15
 
16
  # predict
17
  with torch.no_grad():
@@ -22,7 +28,9 @@ def process_image(image, prompt):
22
  plt.imsave(filename_mask, torch.sigmoid(preds))
23
  mask_image = Image.open(filename_mask).convert("RGB")
24
 
25
- return mask_image
 
 
26
 
27
 
28
 
@@ -36,9 +44,11 @@ interface = gr.Interface(fn=process_image,
36
  inputs=[
37
  gr.Image(type="pil"),
38
  gr.Textbox(label="What to identify"),
 
39
  ],
40
  outputs=[
41
  gr.Image(type="pil"),
 
42
  ],
43
  title=title,
44
  description=description,
 
7
  import matplotlib.pyplot as plt
8
  import cv2
9
 
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+
13
+ auth_token = os.environ.get("HF_TOKEN") or True
14
  clip_seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
15
  clip_seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
16
+ sd_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
17
+
18
 
19
+ def process_image(image, prompt_find, prompt_replace):
20
+ inputs = clip_seg_processor(text=prompt_find, images=image, padding="max_length", return_tensors="pt")
21
 
22
  # predict
23
  with torch.no_grad():
 
28
  plt.imsave(filename_mask, torch.sigmoid(preds))
29
  mask_image = Image.open(filename_mask).convert("RGB")
30
 
31
+ image = sd_inpainting_pipe(prompt=prompt_replace, image=image, mask_image=mask_image).images[0]
32
+
33
+ return mask_image,image
34
 
35
 
36
 
 
44
  inputs=[
45
  gr.Image(type="pil"),
46
  gr.Textbox(label="What to identify"),
47
+ gr.Textbox(label="What to replace"),
48
  ],
49
  outputs=[
50
  gr.Image(type="pil"),
51
+ gr.Image(type="pil"),
52
  ],
53
  title=title,
54
  description=description,