from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np

# Load the CLIPSeg processor and model once at startup
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
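# Optional sketch (an assumption, not part of the original app): move the model
# to a GPU when one is available; the inputs built in process_image would then
# need a matching .to(device) call.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)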

def process_image(image, prompt):
    # Prepare inputs with the processor
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Predict; for a single prompt the logits have shape [1, H, W]
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze()  # drop the batch dimension -> [H, W]
    # Apply sigmoid to convert logits to probabilities
    preds = torch.sigmoid(preds)
    # Convert to a NumPy array
    mask = preds.numpy()
    # Save the mask as a grayscale PNG, then reload it as an RGB PIL image
    filename = "mask.png"
    plt.imsave(filename, mask, cmap='gray')
    return Image.open(filename).convert("RGB")
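# Alternative sketch (an assumption, not the original behavior): build the PIL
# image directly in memory instead of round-tripping through mask.png on disk.
# def mask_to_pil(mask: np.ndarray) -> Image.Image:
#     arr = (mask * 255).astype(np.uint8)
#     return Image.fromarray(arr, mode="L").convert("RGB")
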
title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a></p>"
examples = [["example_image.png", "a description of what to segment"]]
interface = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Please describe what you want to identify")],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
interface.launch(debug=True)
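# Sketch (assumption): when running outside a hosted Space, Gradio can also
# expose a temporary public URL, e.g. interface.launch(debug=True, share=True)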