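# app.py — Gradio demo that chains CLIPSeg text-prompted segmentation with
# Stable Diffusion inpainting: first mask a region by describing it in text,
# then repaint around that region from a second text prompt.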
import os

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Use an explicit Hub token if one is set; otherwise fall back to the cached
# `huggingface-cli login` credentials.
auth_token = os.environ.get("API_TOKEN") or True

# CLIPSeg for zero-shot, text-prompted segmentation
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# Stable Diffusion inpainting pipeline, loaded with fp16 weights
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=auth_token,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)
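
# Pad a PIL image with black borders so it becomes square, keeping the
# original content centered along the shorter axis.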
def pad_image(image):
    w, h = image.size
    if w == h:
        return image
    elif w > h:
        new_image = Image.new(image.mode, (w, w), (0, 0, 0))
        new_image.paste(image, (0, (w - h) // 2))
        return new_image
    else:
        new_image = Image.new(image.mode, (h, h), (0, 0, 0))
        new_image.paste(image, ((h - w) // 2, 0))
        return new_image
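
# Run CLIPSeg for a single text prompt and return a soft mask in [0, 1]
# at the size of the input image.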
def process_image(image, prompt):
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    # predict
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits
    pred = torch.sigmoid(preds)
    mat = pred.cpu().numpy()
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.convert("RGB")
    mask = mask.resize(image.size)
    mask = np.array(mask)[:, :, 0]
    # normalize the mask to [0, 1]; the epsilon guards against division by
    # zero when a prompt produces a constant mask
    mask_min = mask.min()
    mask_max = mask.max()
    mask = (mask - mask_min) / max(mask_max - mask_min, 1e-8)
    return mask
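
# Compute one boolean mask per comma-separated prompt by thresholding the
# CLIPSeg output.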
def get_masks(prompts, img, threshold):
    prompts = prompts.split(",")
    masks = []
    for prompt in prompts:
        mask = process_image(img, prompt)
        mask = mask > threshold
        masks.append(mask)
    return masks
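
# Segment the image: keep pixels matched by any positive prompt and not
# matched by any negative prompt. Returns the cut-out, the mask, and the
# inverse mask (the region that gets inpainted later).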
def extract_image(img, pos_prompts, neg_prompts, threshold):
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)
    # combine each group of masks into one mask with a logical OR
    pos_mask = np.any(np.stack(positive_masks), axis=0)
    neg_mask = np.any(np.stack(negative_masks), axis=0)
    final_mask = pos_mask & ~neg_mask
    # invert on the boolean array before converting to 0/255 images
    inverse_mask = (~final_mask).astype(np.uint8) * 255
    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
    # extract the final image
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask, inverse_mask
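
# Inpaint everything outside the selected region: the pipeline repaints
# wherever the mask image is white, which here is the inverse mask.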
def inpaint_image(img, mask, prompt):
    # the inpainting pipeline expects square 512x512 RGB inputs
    img = pad_image(img.convert("RGB")).resize((512, 512))
    # the inverse mask already holds 0/255 values, so no rescaling is needed
    mask = Image.fromarray(np.asarray(mask, dtype=np.uint8)).convert("RGB")
    mask = pad_image(mask).resize((512, 512))
    with torch.cuda.amp.autocast(True):
        inpainted_image = pipe(prompt=prompt, image=img, mask_image=mask).images[0]
    return inpainted_image
title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = "Demo combining CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation, with Stable Diffusion inpainting. Upload an image, describe what to identify and what to ignore as comma-separated prompts, and click 'Mask'. Then enter a prompt and click 'Run' to repaint everything outside the mask. Results show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"
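
# Two-step UI: the first row masks the image from text prompts, the second
# row inpaints the inverse mask with Stable Diffusion.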
with gr.Blocks() as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            btn_mask = gr.Button("Mask")
        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")
            inverse_mask = gr.Image(label="Inverse")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            # note: this slider is not yet wired into the run callback
            input_slider_S = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Image Strength"
            )
            btn_run = gr.Button("Run")
        with gr.Column():
            inpainted_image = gr.Image(label="Inpainted")

    btn_mask.click(
        extract_image,
        inputs=[
            input_image,
            positive_prompts,
            negative_prompts,
            input_slider_T,
        ],
        outputs=[output_image, output_mask, inverse_mask],
        api_name="mask",
    )
    btn_run.click(
        inpaint_image,
        inputs=[
            input_image,
            inverse_mask,
            prompt,
        ],
        outputs=[inpainted_image],
        api_name="run",
    )

demo.launch()