from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def process_image(image, prompt):
    # Prepare inputs with the processor
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze()  # assuming the output logits have shape [1, H, W]; squeeze to [H, W]

    # Apply sigmoid to convert logits to probabilities
    preds = torch.sigmoid(preds)

    # Convert to a NumPy array
    mask = preds.numpy()

    # Save the mask as a grayscale image, handling dimensions correctly
    filename = "mask.png"
    plt.imsave(filename, mask, cmap="gray")  # cmap="gray" for grayscale saving

    # Reload as a PIL image and return
    return Image.open(filename).convert("RGB")


title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation."

# The article links to the CLIPSeg paper (arXiv:2112.10003)
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a></p>"
examples = [["example_image.png", "a description of what to segment"]]

interface = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Please describe what you want to identify")],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
interface.launch(debug=True)