# Scraped from a Hugging Face Space page.
# Space status at capture time: Runtime error
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline as StableDiffusionInpaintPipeline
from PIL import Image
from torch import autocast
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# Hub token taken from the API_TOKEN environment variable. When it is unset,
# `True` is passed instead -- presumably so the Hub client falls back to a
# locally stored login token; confirm against the installed diffusers version.
auth_token = os.environ.get("API_TOKEN") or True

# CLIPSeg: text-prompted segmentation model plus its matching processor.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Only load half-precision weights when a GPU is present. Many CPU kernels
# have no float16 implementation, so an fp16 pipeline moved to CPU fails at
# inference time instead of merely running slowly.
if device == "cuda":
    _precision_kwargs = {"revision": "fp16", "torch_dtype": torch.float16}
else:
    _precision_kwargs = {"torch_dtype": torch.float32}

# Stable Diffusion inpainting pipeline, moved to the selected device.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    use_auth_token=auth_token,
    **_precision_kwargs,
)
pipe = pipe.to(device)
def pad_image(image):
    """Pad ``image`` with black borders to a square, keeping it centred.

    Args:
        image: PIL image of any mode.

    Returns:
        ``image`` itself when it is already square; otherwise a new image in
        the same mode whose side is ``max(width, height)``, with the original
        pasted in the middle of the longer axis.
    """
    width, height = image.size
    if width == height:
        return image

    side = max(width, height)
    # Single-band modes ("L", "1", "I", "F") need a scalar fill colour; recent
    # Pillow raises TypeError for a 3-tuple there. Multi-band modes and "P"
    # (where Pillow maps an RGB tuple through the palette) keep the original
    # (0, 0, 0) fill.
    if len(image.getbands()) == 1 and image.mode != "P":
        fill = 0
    else:
        fill = (0, 0, 0)

    padded = Image.new(image.mode, (side, side), fill)
    padded.paste(image, ((side - width) // 2, (side - height) // 2))
    return padded
def process_image(image, prompt):
    """Return CLIPSeg's heat-map for ``prompt`` over ``image``, scaled to [0, 1].

    Args:
        image: PIL image to segment.
        prompt: Text describing the region to locate.

    Returns:
        Float64 ``numpy`` array of shape ``(image.height, image.width)``,
        min-max normalised so its lowest value is 0 and highest is 1. When the
        map is constant (normalising would compute 0/0 and yield NaNs), the
        8-bit-quantised sigmoid probabilities divided by 255 are returned
        instead. Assumes the model returns a single 2-D logit map for one
        prompt/image pair -- the ``"L"`` conversion below requires that.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()

    # Quantise to 8 bits and let PIL upsample to the input resolution.
    heat = Image.fromarray(np.uint8(probabilities * 255), "L").resize(image.size)
    heat = np.asarray(heat, dtype=np.float64)

    low, high = heat.min(), heat.max()
    if high == low:
        # Flat map: nothing to stretch, fall back to the raw probabilities.
        return heat / 255.0
    return (heat - low) / (high - low)
def get_masks(prompts, img, threhsold):
    """Return one boolean mask per non-empty comma-separated prompt.

    Args:
        prompts: Comma-separated text prompts. Entries that are blank after
            stripping (an empty textbox, a trailing comma) are skipped, so
            CLIPSeg is never run on an empty prompt.
        img: PIL image to segment.
        threhsold: Cut-off applied to each normalised heat-map from
            ``process_image``; pixels strictly above it become ``True``.
            (Misspelt name kept so keyword callers keep working.)

    Returns:
        List of boolean arrays shaped like ``process_image`` output; may be
        empty.
    """
    masks = []
    for prompt in (prompts or "").split(","):
        prompt = prompt.strip()
        if not prompt:
            continue
        masks.append(process_image(img, prompt) > threhsold)
    return masks


def _union(masks, shape):
    """Logical OR of ``masks``; an all-False array of ``shape`` when empty."""
    if not masks:
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


def extract_image(img, pos_prompts, neg_prompts, threshold):
    """Cut out the regions matching ``pos_prompts`` but not ``neg_prompts``.

    Args:
        img: PIL image to segment.
        pos_prompts: Comma-separated prompts for what to keep.
        neg_prompts: Comma-separated prompts for what to exclude; may be blank.
        threshold: Per-prompt heat-map cut-off passed to ``get_masks``.

    Returns:
        Tuple ``(output_image, final_mask, inverse_mask)``:
        - ``output_image``: RGBA image, transparent outside the selection.
        - ``final_mask``: ``"L"`` PIL image, 255 inside the selection, 0 outside.
        - ``inverse_mask``: uint8 array of shape ``(height, width)``, 255
          outside the selection and 0 inside.
        If no positive prompt is given, the selection is empty.
    """
    shape = (img.height, img.width)
    pos_mask = _union(get_masks(pos_prompts, img, threshold), shape)
    neg_mask = _union(get_masks(neg_prompts, img, threshold), shape)
    selected = pos_mask & ~neg_mask

    final_mask = Image.fromarray(selected.astype(np.uint8) * 255, "L")
    inverse_mask = (~selected).astype(np.uint8) * 255

    # Paste the source through the mask onto a fully transparent canvas.
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask, inverse_mask
def inpaint_image(img, mask, prompt):
    """Repaint the white area of ``mask`` in ``img`` according to ``prompt``.

    Args:
        img: PIL image to edit.
        mask: Array-like or PIL image, same size as ``img``. Either 0/1 values
            (bool, or numbers no greater than 1) or 0/255 8-bit values are
            accepted -- e.g. the uint8 ``inverse_mask`` that ``extract_image``
            produces (presumably handed back unchanged by the Gradio Image
            component; confirm).
        prompt: Text prompt for the inpainting pipeline.

    Returns:
        The first image returned by the pipeline for the 512x512 square-padded
        input. It is not cropped back to the original aspect ratio.
    """
    img = pad_image(img).convert("RGB").resize((512, 512))

    mask = np.asarray(mask)
    # Only scale masks that are on a 0..1 scale. Multiplying an already
    # 0/255 uint8 mask by 255 wraps (255 * 255 -> 1) and empties the mask.
    if mask.dtype == bool or mask.max() <= 1:
        mask = mask.astype(np.float32) * 255
    # Convert to RGB before padding so the (0, 0, 0) border fill always applies.
    mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB")
    mask = pad_image(mask).convert("RGB").resize((512, 512))

    # Mixed precision only applies on a GPU; on CPU the context is disabled
    # instead of emitting a CUDA-unavailable warning.
    with autocast("cuda", enabled=torch.cuda.is_available()):
        inpainted_image = pipe(prompt=prompt, image=img, mask_image=mask).images[0]
    return inpainted_image
# Page texts. NOTE(review): `title` is not referenced by the UI code shown here.
title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
# Usage text shown under the header; it refers to the UI's "Mask" and "Run"
# buttons (there are no example inputs or "submit" button in this demo).
description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. To use it, upload an image, describe what to mask (identify in the image) and, optionally, what to ignore, then click 'Mask'. To inpaint, enter a prompt and click 'Run'. Results will show up in a few seconds."
# HTML footer linking the paper and the Transformers docs.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"
# Two-step UI: (1) segment the image from text prompts, (2) inpaint using the
# inverse of that segmentation.
with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)

    # Step 1: segmentation inputs (left) and results (right).
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            # A button's caption is its `value`; `label=` does not set it.
            btn_mask = gr.Button("Mask")
        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")
            inverse_mask = gr.Image(label="Inverse")

    # Step 2: inpainting prompt (left) and result (right).
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt"
            )
            # NOTE(review): this slider is not passed to any event handler
            # below, so it currently has no effect on the output.
            input_slider_S = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Image Strength"
            )
            btn_run = gr.Button("Run")
        with gr.Column():
            inpainted_image = gr.Image(label="Inpainted")

    btn_mask.click(
        extract_image,
        inputs=[
            input_image,
            positive_prompts,
            negative_prompts,
            input_slider_T,
        ],
        outputs=[output_image, output_mask, inverse_mask],
        api_name="mask"
    )
    # Inpaints where the inverse mask is white, i.e. outside the selection.
    btn_run.click(
        inpaint_image,
        inputs=[
            input_image,
            inverse_mask,
            prompt,
        ],
        outputs=[inpainted_image],
        api_name="run"
    )

demo.launch()