merve HF staff commited on
Commit
d64e6bd
·
verified ·
1 Parent(s): fb9eac5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -4
app.py CHANGED
@@ -1,4 +1,63 @@
1
- transformers
2
- torch
3
- spaces
4
- matplotlib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, SegGptImageProcessor, SegGptForImageSegmentation
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+
7
+ depth_anything = pipeline(task = "depth-estimation", model="nielsr/depth-anything-small", device=0)
8
+ checkpoint = "BAAI/seggpt-vit-large"
9
+ image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
10
+ model = SegGptForImageSegmentation.from_pretrained(checkpoint)
11
+
12
+ def infer_seggpt(image_input, image_prompt, mask_prompt):
13
+ num_labels = 100
14
+ inputs = image_processor(
15
+ images=image_input,
16
+ prompt_images=image_prompt,
17
+ prompt_masks=mask_prompt,
18
+ return_tensors="pt",
19
+ num_labels=num_labels
20
+ )
21
+ with torch.no_grad():
22
+ outputs = model(**inputs)
23
+
24
+ target_sizes = [image_input.shape[:2]]
25
+
26
+ mask = image_processor.post_process_semantic_segmentation(outputs, target_sizes, num_labels=num_labels)[0]
27
+ palette = image_processor.get_palette(num_labels)
28
+ fig, ax = plt.subplots()
29
+ plt.gca().get_xaxis().get_major_formatter().set_useOffset(False)
30
+ mask_rgb = image_processor.mask_to_rgb(mask.cpu().numpy(), palette, data_format="channels_last")
31
+ print(mask_rgb.shape, image_input.shape)
32
+ ax.imshow(Image.fromarray(image_input))
33
+ ax.imshow(mask_rgb, cmap='viridis', alpha=0.6)
34
+
35
+ ax.axis("off")
36
+ ax.margins(0)
37
+ plt.show()
38
+ plt.savefig("masks.png", bbox_inches='tight', pad_inches=0)
39
+ return "masks.png"
40
+
41
+ def infer(image_input, image_prompt, mask_prompt):
42
+ sg_masks = []
43
+ mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")
44
+
45
+ sg_mask = infer_seggpt(np.asarray(image_input), np.asarray(image_prompt),
46
+ np.asarray(mask_prompt))
47
+
48
+ return sg_mask
49
+
50
+ import gradio as gr
51
+
52
+ demo = gr.Interface(
53
+ infer,
54
+ inputs=[gr.Image(type="pil", label="Image Input"), gr.Image(type="pil", label="Image Prompt")],
55
+ outputs=[gr.Image(type="filepath", label="Mask Output")],
56
+ #gr.Image(type="numpy", label="Output Mask")],
57
+ title="SegGPT 🤝 Depth Anything: Speak to Segmentation in Image",
58
+ description="SegGPT is a one-shot image segmentation model where one could ask model what to segment through uploading an example image and an example mask, and ask to segment the same thing in another image. In this demo, we have combined SegGPT and Depth Anything to automatically generate the mask for most outstanding object and segment the same thing in another image for you. You can see how it works by trying the example.",
59
+
60
+ examples=[
61
+ ["./cats.png", "./cat.png"],
62
+ ])
63
+ demo.launch(debug=True)