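"""SegGPT 🤝 Depth Anything: one-shot segmentation demo.

The mask prompt is generated automatically from a Depth Anything depth map of
the prompt image, so users only need to upload two images.
"""
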
from transformers import pipeline, SegGptImageProcessor, SegGptForImageSegmentation
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import spaces  # Hugging Face Spaces ZeroGPU helper, provides the @spaces.GPU decorator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Depth Anything estimates a relative depth map; we reuse it as an automatic mask prompt.
depth_anything = pipeline(task="depth-estimation", model="nielsr/depth-anything-small", device=device)

# SegGPT segments a target image given an (image, mask) prompt pair.
checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
model = SegGptForImageSegmentation.from_pretrained(checkpoint).to(device)

def infer_seggpt(image_input, image_prompt, mask_prompt):
    num_labels = 100  # size of the color palette used to encode the prompt mask
    inputs = image_processor(
        images=image_input,
        prompt_images=image_prompt,
        prompt_masks=mask_prompt,
        return_tensors="pt",
        num_labels=num_labels,
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the prediction back to the input image's (height, width).
    target_sizes = [image_input.shape[:2]]
    mask = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes, num_labels=num_labels
    )[0]

    # Colorize the label mask and overlay it on the input image.
    palette = image_processor.get_palette(num_labels)
    mask_rgb = image_processor.mask_to_rgb(mask.cpu().numpy(), palette, data_format="channels_last")
    fig, ax = plt.subplots()
    ax.imshow(Image.fromarray(image_input))
    ax.imshow(mask_rgb, alpha=0.6)  # mask_rgb is already RGB, so no cmap is needed
    ax.axis("off")
    ax.margins(0)
    fig.savefig("masks.png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # release the figure so repeated calls don't leak memory
    return "masks.png"
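
# Direct-usage sketch (outside the Gradio app), assuming the bundled example
# files cat.png / cats.png are present:
#   prompt = Image.open("cat.png").convert("RGB")
#   depth = depth_anything(prompt)["depth"].convert("RGB")
#   infer_seggpt(np.asarray(Image.open("cats.png").convert("RGB")),
#                np.asarray(prompt), np.asarray(depth))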

@spaces.GPU
def infer(image_input, image_prompt):
    # Use the prompt image's depth map as a pseudo-mask: the most salient
    # (closest) object dominates it, so no hand-drawn mask is required.
    mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")
    return infer_seggpt(
        np.asarray(image_input), np.asarray(image_prompt), np.asarray(mask_prompt)
    )

import gradio as gr

demo = gr.Interface(
    infer,
    inputs=[gr.Image(type="pil", label="Image Input"), gr.Image(type="pil", label="Image Prompt")],
    outputs=[gr.Image(type="filepath", label="Mask Output")],
    title="SegGPT 🤝 Depth Anything: Speak to Segmentation in Image",
    description=(
        "SegGPT is a one-shot image segmentation model: you show it what to segment by uploading "
        "an example image and an example mask, and it segments the same thing in another image. "
        "This demo combines SegGPT with Depth Anything to generate the mask for the most salient "
        "object automatically, so only two images are needed. Try the example to see how it works."
    ),
    examples=[
        ["./cats.png", "./cat.png"],
    ],
)
demo.launch(debug=True)
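
# To run locally (sketch): `python app.py`, then open the printed local URL.
# On Hugging Face Spaces, the @spaces.GPU decorator requests a ZeroGPU device
# for each call to infer().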