"""Gradio demo: one-shot segmentation with SegGPT, prompted by Depth Anything.

The depth map of the prompt image (from Depth Anything) is used as SegGPT's
prompt mask; SegGPT then segments the corresponding region(s) in the input
image and the result is returned as an overlay PNG.
"""

import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import pipeline, SegGptImageProcessor, SegGptForImageSegmentation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

depth_anything = pipeline(task="depth-estimation", model="nielsr/depth-anything-small", device=device)

checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
model = SegGptForImageSegmentation.from_pretrained(checkpoint).to(device)


def infer_seggpt(image_input, image_prompt, mask_prompt, num_labels=100):
    """Segment ``image_input`` guided by an example image/mask pair and save an overlay.

    Args:
        image_input: Image to segment, as a NumPy array; its first two
            dimensions are used as the (height, width) of the output mask.
        image_prompt: Example image, as a NumPy array.
        mask_prompt: Prompt mask for ``image_prompt``, as a NumPy array
            (in this app, an RGB-converted depth map).
        num_labels: Number of labels used to build the palette that encodes and
            decodes SegGPT's RGB masks (default 100, as before).

    Returns:
        Path to a freshly created PNG file showing the predicted mask
        overlaid (alpha 0.6) on ``image_input``.
    """
    inputs = image_processor(
        images=image_input,
        prompt_images=image_prompt,
        prompt_masks=mask_prompt,
        return_tensors="pt",
        num_labels=num_labels,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the prediction back to the input image's (height, width).
    target_sizes = [image_input.shape[:2]]
    mask = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes, num_labels=num_labels
    )[0]

    palette = image_processor.get_palette(num_labels)
    mask_rgb = image_processor.mask_to_rgb(mask.cpu().numpy(), palette, data_format="channels_last")

    # A unique file per request, so concurrent users never overwrite each
    # other's result (the original always wrote "masks.png").
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        output_path = tmp.name

    fig, ax = plt.subplots()
    try:
        ax.imshow(Image.fromarray(image_input))
        # mask_rgb is already RGB, so no colormap applies.
        ax.imshow(mask_rgb, alpha=0.6)
        ax.axis("off")
        ax.margins(0)
        # Save directly instead of plt.show(): the app runs headless, and
        # showing before saving can leave an empty canvas on interactive backends.
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    finally:
        # pyplot keeps every open figure alive; close it to avoid a per-request leak.
        plt.close(fig)
    return output_path


@spaces.GPU
def infer(image_input, image_prompt, mask_prompt=None):
    """Segment in ``image_input`` the region suggested by ``image_prompt``'s depth map.

    Args:
        image_input: PIL image to segment.
        image_prompt: PIL example image; its Depth Anything depth map, converted
            to RGB, becomes SegGPT's prompt mask.
        mask_prompt: Ignored. The prompt mask is always derived from the depth
            map. It is kept (now optional) for backward compatibility, because
            the Gradio interface only supplies two inputs.

    Returns:
        Path to the PNG overlay produced by ``infer_seggpt``.

    Raises:
        gr.Error: If either image is missing.
    """
    if image_input is None or image_prompt is None:
        raise gr.Error("Please provide both an input image and a prompt image.")

    # PNG uploads may carry an alpha channel; the SegGPT processor is presumably
    # built for 3-channel RGB input (per-channel mean/std) — confirm if changing checkpoints.
    image_input = image_input.convert("RGB")
    image_prompt = image_prompt.convert("RGB")

    mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")
    return infer_seggpt(np.asarray(image_input), np.asarray(image_prompt), np.asarray(mask_prompt))


demo = gr.Interface(
    infer,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Image(type="pil", label="Image Prompt"),
    ],
    outputs=[gr.Image(type="filepath", label="Mask Output")],
    title="SegGPT 🤝 Depth Anything: Speak to Segmentation in Image",
    description=(
        "SegGPT is a one-shot image segmentation model where one could ask model what to segment "
        "through uploading an example image and an example mask, and ask to segment the same thing "
        "in another image. In this demo, we have combined SegGPT and Depth Anything to automatically "
        "generate the mask for most outstanding object and segment the same thing in another image "
        "for you. You can see how it works by trying the example."
    ),
    examples=[
        ["./cats.png", "./cat.png"],
    ],
)

if __name__ == "__main__":
    demo.launch(debug=True)