import os

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

device = "cuda" if torch.cuda.is_available() else "cpu"
auth_token = os.environ.get("HF_TOKEN") or True

# CLIPSeg turns the "what to identify" prompt into a segmentation mask;
# Stable Diffusion inpainting then fills the masked region from the
# "what to replace it with" prompt.
clip_seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
sd_inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=auth_token,
).to(device)

WIDTH = 512
HEIGHT = 512
DILATE = 10
THRESHOLDS = 0.1


def dilate_mask(mask_file):
    """Grow the mask so the inpainting covers the segment's edges, then re-binarize."""
    image = cv2.imread(mask_file, 0)
    kernel = np.ones((DILATE, DILATE), np.uint8)
    dilated = cv2.dilate(image, kernel, iterations=1)
    im_bin = ((dilated > 127) * 255).astype(np.uint8)
    cv2.imwrite(mask_file, im_bin)
    return mask_file


def process_mask(prompt_find, image, thresholds=THRESHOLDS):
    """Segment the region matching prompt_find with CLIPSeg and return it as a mask."""
    inputs = clip_seg_processor(
        text=prompt_find, images=image, padding="max_length", return_tensors="pt"
    )

    # Predict the segmentation logits.
    with torch.no_grad():
        outputs = clip_seg_model(**inputs)
    preds = outputs.logits

    # Min-max normalize the sigmoid output to [0, 1], then threshold.
    out_img = torch.sigmoid(preds)
    out_img = (out_img - out_img.min()) / (out_img.max() - out_img.min())
    if isinstance(thresholds, list):
        if len(thresholds) >= 2:
            # Two thresholds: clamp high values to 1 and low values to 0,
            # keeping soft values in the band between them.
            out_img = torch.where(out_img >= thresholds[1], 1., out_img)
            out_img = torch.where(out_img <= thresholds[0], 0., out_img)
        else:
            out_img = torch.where(out_img >= thresholds[0], 1., 0.)
    else:
        out_img = torch.where(out_img >= thresholds, 1., 0.)

    # matplotlib writes a colormapped PNG; dilate_mask re-binarizes it on disk.
    mask_file = "mask.png"
    plt.imsave(mask_file, out_img)
    dilated_mask = dilate_mask(mask_file)
    mask_image = Image.open(dilated_mask)
    return mask_image


def process_inpaint(prompt_replace, image, mask_image):
    """Replace the masked region of the image according to prompt_replace."""
    image = sd_inpainting_pipe(
        prompt=prompt_replace, image=image, mask_image=mask_image
    ).images[0]
    return image


def process_image(image, prompt_find, prompt_replace):
    # Stable Diffusion inpainting expects 512x512 inputs.
    orig_image = image.resize((WIDTH, HEIGHT))
    mask_image = process_mask(prompt_find, orig_image).resize((WIDTH, HEIGHT))
    new_image = process_inpaint(prompt_replace, orig_image, mask_image)
    return new_image, mask_image


title = "Interactive demo: prompt-based inpainting using CLIPSeg x Stable Diffusion"
description = (
    "Demo for prompt-based inpainting. It uses CLIPSeg, a CLIP-based model for "
    "zero- and one-shot image segmentation, to identify the image segment matching "
    "the first prompt, and Stable Diffusion inpainting to replace it according to "
    "the second prompt. To use it, upload an image and fill in both prompts, or "
    "use one of the examples below, and click 'submit'. Results will show up in "
    "a few seconds."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2112.10003' target='_blank'>"
    "CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | "
    "<a href='https://huggingface.co/docs/transformers/model_doc/clipseg' target='_blank'>"
    "HuggingFace docs</a></p>"
)
" interface = gr.Interface(fn=process_image, inputs=[ gr.Image(type="pil"), gr.Textbox(label="What to identify"), gr.Textbox(label="What to replace it with"), ], outputs=[ gr.Image(type="pil"), gr.Image(type="pil"), ], title=title, description=description, article=article) interface.launch(debug=True)