import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os

from diffusers import StableDiffusionInpaintPipeline
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from PIL import Image


device = "cuda" if torch.cuda.is_available() else "cpu"


auth_token = os.environ.get("HF_TOKEN") or True
clip_seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
sd_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)

WIDTH = 512        # Stable Diffusion inpainting works on 512x512 inputs
HEIGHT = 512
DILATE = 10        # kernel size (pixels) used to grow the mask before inpainting
THRESHOLDS = 0.1   # sigmoid cutoff for turning CLIPSeg scores into a binary mask
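# THRESHOLDS may also be a two-element [low, high] list, e.g. [0.1, 0.9] (illustrative
# values only): process_mask then zeroes scores <= low, sets scores >= high to 1, and
# keeps anything in between as a soft mask.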


def dilate_mask(mask_file):
  """Grow the mask by a few pixels so inpainting covers the object's edges."""
  image = cv2.imread(mask_file, 0)
  kernel = np.ones((DILATE, DILATE), np.uint8)
  dilated = cv2.dilate(image, kernel, iterations=1)
  im_bin = ((dilated > 127) * 255).astype(np.uint8)
  cv2.imwrite(mask_file, im_bin)
  return mask_file

def process_mask(prompt_find, image, thresholds=THRESHOLDS):
  """Segment `prompt_find` in `image` with CLIPSeg and return a binary PIL mask."""
  inputs = clip_seg_processor(
    text=prompt_find,
    images=image,
    padding="max_length",
    return_tensors="pt"
  )

  # Predict the segmentation logits.
  with torch.no_grad():
    outputs = clip_seg_model(**inputs)
    preds = outputs.logits

  # Normalize scores to [0, 1], then binarize with either a single threshold
  # or a [low, high] pair that leaves intermediate values untouched.
  out_img = torch.sigmoid(preds)
  out_img = (out_img - out_img.min()) / (out_img.max() - out_img.min())
  if isinstance(thresholds, list):
    if len(thresholds) >= 2:
      out_img = torch.where(out_img >= thresholds[1], 1., out_img)
      out_img = torch.where(out_img <= thresholds[0], 0., out_img)
    else:
      out_img = torch.where(out_img >= thresholds[0], 1., 0.)
  else:
    out_img = torch.where(out_img >= thresholds, 1., 0.)

  # Save the mask, dilate it on disk, and load it back as a PIL image.
  mask_file = "mask.png"
  plt.imsave(mask_file, out_img)
  dilated_mask = dilate_mask(mask_file)

  mask_image = Image.open(dilated_mask)

  return mask_image

def process_inpaint(prompt_replace, image, mask_image):
  """Inpaint the masked region of `image` with content described by `prompt_replace`."""
  image = sd_inpainting_pipe(
    prompt=prompt_replace,
    image=image,
    mask_image=mask_image
  ).images[0]
  return image

def process_image(image, prompt_find, prompt_replace):
  """Full pipeline: resize, build the CLIPSeg mask, then inpaint with Stable Diffusion."""
  orig_image = image.resize((WIDTH, HEIGHT))
  mask_image = process_mask(prompt_find, orig_image).resize((WIDTH, HEIGHT))
  new_image = process_inpaint(prompt_replace, orig_image, mask_image)

  return new_image, mask_image
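
# A minimal sketch of running the pipeline directly, without the Gradio UI.
# "example.jpg" and both prompts are placeholder values, not assets shipped with
# this Space; uncomment and adjust the paths and prompts to try it locally.
#
# if __name__ == "__main__":
#   test_image = Image.open("example.jpg").convert("RGB")
#   result, mask = process_image(test_image, "a dog", "a cat")
#   result.save("inpainted.png")
#   mask.save("mask_used.png")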



title = "Interactive demo: Prompt based inPainting using CLIPSeg x Stable Diffusion"

description = "Demo for prompt based inPainting. It uses CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. Once it identifies the image segment based on a text mask, or use one of the examples below and click 'submit'. Results will show up in a few seconds."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"

interface = gr.Interface(fn=process_image,
                     inputs=[
                        gr.Image(type="pil"),
                        gr.Textbox(label="What to identify"),
                        gr.Textbox(label="What to replace it with"),
                     ],
                     outputs=[
                        gr.Image(type="pil", label="Inpainted image"),
                        gr.Image(type="pil", label="Mask"),
                     ],
                     title=title,
                     description=description,
                     article=article)

interface.launch(debug=True)