JensParslov committed · Commit b3da277
Parent(s): Duplicate from NN-BRD/hackathon_depth_segment

Files changed:
- .gitattributes +35 -0
- README.md +13 -0
- app.py +309 -0
- app_legacy.py +48 -0
- inference.py +448 -0
- packages.txt +1 -0
- requirements.txt +12 -0
- sam_vit_b_01ec64.pth +3 -0
- sam_vit_h_4b8939.pth +3 -0
- tests.py +0 -0
- utils.py +231 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Hackathon
emoji: 👁
colorFrom: purple
colorTo: green
sdk: gradio
sdk_version: 3.39.0
app_file: app.py
pinned: false
duplicated_from: NN-BRD/hackathon_depth_segment
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
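
The front matter above is what tells the Hub how to build and launch this Space (Gradio 3.39.0, entry point app.py). A minimal sketch of reading those fields yourself, assuming PyYAML is available (it is not listed in requirements.txt):

# Illustrative only: parse the Space metadata block from README.md.
import yaml

with open("README.md", encoding="utf-8") as f:
    front_matter = f.read().split("---")[1]

config = yaml.safe_load(front_matter)
print(config["sdk"], config["sdk_version"], config["app_file"])
# -> gradio 3.39.0 app.py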
app.py
ADDED
@@ -0,0 +1,309 @@
import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageOps
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud


sam = SegmentPredictor()
sam_cpu = SegmentPredictor(device="cpu")
dpt = DepthPredictor()
red = (255, 0, 0)
blue = (0, 0, 255)
annos = []


block = gr.Blocks()
with block:
    # States
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
    pred_masks = gr.State([])
    prompt_masks = gr.State([])
    embedding = gr.State()

    # UI
    with gr.Column():
        gr.Markdown(
            """# Segment Anything Model (SAM)
            ## a new AI model from Meta AI that can "cut out" any object, in any image, with a single click 🚀
            SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training. [**Official Project**](https://segment-anything.com/) [**Code**](https://github.com/facebookresearch/segment-anything).
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Tab("Upload Image"):
                    # mirror_webcam = False
                    upload_image = gr.Image(label="Input", type="pil", tool=None)
                with gr.Tab("Webcam"):
                    # mirror_webcam = False
                    input_image = gr.Image(
                        label="Input", type="pil", tool=None, source="webcam"
                    )
                with gr.Row():
                    sam_encode_btn = gr.Button("Encode", variant="primary")
                    sam_sgmt_everything_btn = gr.Button(
                        "Segment Everything!", variant="primary"
                    )
                # sam_encode_status = gr.Label('Not encoded yet')
                with gr.Row():
                    prompt_image = gr.Image(label="Segments")
                    # prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
                    lbl_image = gr.AnnotatedImage(label="Everything")
                with gr.Row():
                    point_label_radio = gr.Radio(label="Point Label", choices=[1, 0], value=1)
                    text = gr.Textbox(label="Mask Name")
                    reset_btn = gr.Button("New Mask")
                selected_masks_image = gr.AnnotatedImage(label="Selected Masks")
            with gr.Row():
                with gr.Column():
                    pcl_figure = gr.Model3D(
                        label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0]
                    )
                    with gr.Row():
                        max_depth = gr.Slider(
                            minimum=0, maximum=10, value=3, step=0.01, label="Max Depth"
                        )
                        min_depth = gr.Slider(
                            minimum=0, maximum=10, step=0.01, value=1, label="Min Depth"
                        )
                        n_samples = gr.Slider(
                            minimum=1e3,
                            maximum=1e6,
                            step=1e3,
                            value=1e5,
                            label="Number of Samples",
                        )
                        cube_size = gr.Slider(
                            minimum=0.00001,
                            maximum=0.001,
                            step=0.000001,
                            value=0.00001,
                            label="Cube size",
                        )
                    depth_reconstruction_btn = gr.Button(
                        "3D Reconstruction", variant="primary"
                    )
                    depth_reconstruction_mask_btn = gr.Button(
                        "Mask Reconstruction", variant="primary"
                    )

        sam_decode_btn = gr.Button("Predict using points!", variant="primary")

    # components: passed as a set, so event handlers receive a dict keyed by component
    components = {
        point_coords,
        point_labels,
        image_edit_trigger,
        masks,
        cutout_idx,
        input_image,
        embedding,
        point_label_radio,
        text,
        reset_btn,
        sam_sgmt_everything_btn,
        sam_decode_btn,
        depth_reconstruction_btn,
        prompt_image,
        lbl_image,
        n_samples,
        max_depth,
        min_depth,
        cube_size,
        selected_masks_image,
    }

    def on_upload_image(input_image, upload_image):
        # Mirror because gradio.image webcam has mirror = True
        upload_image_mirror = ImageOps.mirror(upload_image)
        return [upload_image_mirror, upload_image]

    upload_image.upload(
        on_upload_image, [input_image, upload_image], [input_image, upload_image]
    )

    # event - init coords
    def on_reset_btn_click(input_image):
        # Reset the prompt view and clear the collected point prompts
        # (one value per declared output below).
        return input_image, point_coords_empty(), point_labels_empty()

    reset_btn.click(
        on_reset_btn_click,
        [input_image],
        [prompt_image, point_coords, point_labels],
        queue=False,
    )

    def on_prompt_image_select(
        input_image,
        prompt_image,
        point_coords,
        point_labels,
        point_label_radio,
        text,
        pred_masks,
        embedding,
        evt: gr.SelectData,
    ):
        sam_cpu.dummy_encode(input_image)
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())

        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        sam_masks = sam_cpu.cond_pred(
            pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding
        )
        return [
            prompt_image,
            (input_image, sam_masks),
            point_coords,
            point_labels,
            sam_masks,
        ]

    prompt_image.select(
        on_prompt_image_select,
        [
            input_image,
            prompt_image,
            point_coords,
            point_labels,
            point_label_radio,
            text,
            pred_masks,
            embedding,
        ],
        [prompt_image, lbl_image, point_coords, point_labels, pred_masks],
        queue=True,
    )

    def on_everything_image_select(
        input_image, pred_masks, masks, text, evt: gr.SelectData
    ):
        i = evt.index
        mask = pred_masks[i][0]
        print(mask)
        print(type(mask))
        masks.append((mask, text))
        anno = (input_image, masks)
        return [masks, anno]

    lbl_image.select(
        on_everything_image_select,
        [input_image, pred_masks, masks, text],
        [masks, selected_masks_image],
        queue=False,
    )

    def on_selected_masks_image_select(input_image, masks, evt: gr.SelectData):
        i = evt.index
        del masks[i]
        anno = (input_image, masks)
        return [masks, anno]

    selected_masks_image.select(
        on_selected_masks_image_select,
        [input_image, masks],
        [masks, selected_masks_image],
        queue=False,
    )
    # prompt_lbl_image.select(on_everything_image_select,
    #                         [input_image, prompt_masks, masks, text],
    #                         [masks, selected_masks_image], queue=False)

    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # encode image on click
        embedding = sam.encode(inputs[input_image]).cpu()
        sam_cpu.dummy_encode(inputs[input_image])
        print("encoding done")
        return [inputs[input_image], embedding]

    sam_encode_btn.click(
        on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False
    )

    def on_click_sam_dencode_btn(inputs):
        print("inferencing")
        image = inputs[input_image]
        generated_mask, _, _ = sam.cond_pred(
            pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels])
        )
        inputs[masks].append((generated_mask, inputs[text]))
        print(inputs[masks][0])
        return {prompt_image: (image, inputs[masks])}

    sam_decode_btn.click(
        on_click_sam_dencode_btn,
        components,
        [prompt_image, masks, cutout_idx],
        queue=True,
    )

    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        path = dpt.generate_obj_rgb(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            # masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_btn.click(
        on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False
    )

    def on_depth_reconstruction_mask_btn_click(inputs):
        print("depth reconstruction")
        path = dpt.generate_obj_masks2(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_mask_btn.click(
        on_depth_reconstruction_mask_btn_click, components, [pcl_figure], queue=False
    )

    def on_sam_sgmt_everything_btn_click(inputs):
        print("segmenting everything")
        image = inputs[input_image]
        sam_masks = sam.segment_everything(image)
        print(image)
        print(sam_masks)
        return [(image, sam_masks), sam_masks]

    sam_sgmt_everything_btn.click(
        on_sam_sgmt_everything_btn_click,
        components,
        [lbl_image, pred_masks],
        queue=True,
    )


if __name__ == "__main__":
    block.queue()
    block.launch(auth=("novouser", "bstad2023"))
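
One pattern worth noting in app.py: the click handlers receive `components` as a set rather than a list, so Gradio passes each handler a single dictionary keyed by component (hence `inputs[input_image]`, `inputs[cube_size]`, and so on). A minimal, self-contained sketch of the same pattern, with toy component names that are not part of this Space:

# Illustrative only: "set of components in, dict of values out" in gr.Blocks.
import gradio as gr

demo = gr.Blocks()
with demo:
    name = gr.Textbox(label="Name")
    times = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Times")
    out = gr.Textbox(label="Greeting")
    btn = gr.Button("Greet")

    handler_inputs = {name, times}  # a set, so the handler gets a dict keyed by component

    def on_click(values):
        # values[name] / values[times] mirror inputs[input_image] etc. in app.py
        return ", ".join([f"Hello {values[name]}!"] * int(values[times]))

    btn.click(on_click, handler_inputs, [out])

if __name__ == "__main__":
    demo.launch()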
app_legacy.py
ADDED
@@ -0,0 +1,48 @@
import gradio as gr
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import supervision as sv
from inference import DepthPredictor, SegmentPredictor
from utils import create_3d_obj, create_3d_pc, point_cloud, generate_PCL
import numpy as np

def produce_depth_map(image):
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(image)
    return depth_result

def produce_segmentation_map(image):
    segment_predictor = SegmentPredictor()
    sam_result = segment_predictor.predict(image)
    return sam_result

def produce_3d_reconstruction(image):
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(image)
    rgb_gltf_path = create_3d_obj(np.array(image), depth_result, path='./rgb.gltf')
    return rgb_gltf_path

def produce_point_cloud(depth_map, segmentation_map):
    return point_cloud(np.array(segmentation_map), depth_map)

def snap(image, depth_map, segmentation_map):
    depth_result = produce_depth_map(image) if depth_map else None
    sam_result = produce_segmentation_map(image) if segmentation_map else None
    rgb_gltf_path = produce_3d_reconstruction(image) if depth_map else None
    point_cloud_fig = produce_point_cloud(depth_result, sam_result) if (segmentation_map and depth_map) else None

    return [image, depth_result, sam_result, rgb_gltf_path, point_cloud_fig]

demo = gr.Interface(
    snap,
    inputs=[gr.Image(source="webcam", tool=None, label="Input Image", type="pil"),
            "checkbox",
            "checkbox"],
    outputs=[gr.Image(label="RGB"),
             gr.Image(label="predicted depth"),
             gr.Image(label="predicted segmentation"),
             gr.Model3D(label="3D mesh reconstruction - RGB",
                        clear_color=[1.0, 1.0, 1.0, 1.0]),
             gr.Plot()]
)

if __name__ == "__main__":
    demo.launch()
inference.py
ADDED
@@ -0,0 +1,448 @@
from transformers import DPTImageProcessor, DPTForDepthEstimation
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
import gradio as gr
import supervision as sv
import torch
import numpy as np
from PIL import Image
import requests
import open3d as o3d
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt

def remove_outliers(point_cloud, threshold=3.0):
    # Calculate mean and standard deviation along each dimension
    mean = np.mean(point_cloud, axis=0)
    std = np.std(point_cloud, axis=0)

    # Define lower and upper bounds for each dimension
    lower_bounds = mean - threshold * std
    upper_bounds = mean + threshold * std

    # Create a boolean mask for points within the bounds
    mask = np.all((point_cloud >= lower_bounds) & (point_cloud <= upper_bounds), axis=1)

    # Filter out outlier points
    filtered_point_cloud = point_cloud[mask]

    return filtered_point_cloud


def map_image_range(depth, min_value, max_value):
    """
    Maps the values of a numpy image array to a specified range.

    Args:
        image (numpy.ndarray): Input image array with values ranging from 0 to 1.
        min_value (float): Minimum value of the new range.
        max_value (float): Maximum value of the new range.

    Returns:
        numpy.ndarray: Image array with values mapped to the specified range.
    """
    # Ensure the input image is a numpy array
    print(np.min(depth))
    print(np.max(depth))
    depth = np.array(depth)
    # normalize the depth values to [0, 1]
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    # invert
    depth = 1 - depth
    print(np.min(depth))
    print(np.max(depth))
    # Map the values to the specified range
    mapped_image = (depth - 0) * (max_value - min_value) / (1 - 0) + min_value
    print(np.min(mapped_image))
    print(np.max(mapped_image))
    return mapped_image


def PCL(mask, depth):
    assert mask.shape == depth.shape
    assert type(mask) == np.ndarray
    assert type(depth) == np.ndarray
    rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3)).astype("uint8")
    rgb_mask[mask] = (255, 0, 0)
    print(np.unique(rgb_mask))
    depth_o3d = o3d.geometry.Image(depth)
    image_o3d = o3d.geometry.Image(rgb_mask)
    # print(len(depth_o3d))
    # print(len(image_o3d))
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )
    # Step 3: Create a PointCloud from the RGBD image
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image,
        o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
        ),
    )
    # Step 4: Convert PointCloud data to a NumPy array
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    print(np.unique(colors, axis=0))
    print(np.unique(colors, axis=1))
    print(np.unique(colors))
    mask = colors[:, 0] == 1.0
    print(mask.sum())
    print(colors.shape)
    points = points[mask]
    colors = colors[mask]
    return points, colors


def PCL_rgb(rgb, depth):
    # assert rgb.shape == depth.shape
    assert type(rgb) == np.ndarray
    assert type(depth) == np.ndarray
    depth_o3d = o3d.geometry.Image(depth)
    image_o3d = o3d.geometry.Image(rgb)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )
    # Step 3: Create a PointCloud from the RGBD image
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image,
        o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
        ),
    )
    # Step 4: Convert PointCloud data to a NumPy array
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    return points, colors


class DepthPredictor:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
        self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
        self.model.eval()

    def predict(self, image):
        # prepare image for the model
        encoding = self.feature_extractor(image, return_tensors="pt")
        # forward pass
        with torch.no_grad():
            outputs = self.model(**encoding)
            predicted_depth = outputs.predicted_depth
        # interpolate to original size
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

        output = prediction.cpu().numpy()
        # output = 1 - (output/np.max(output))
        return output

    def generate_pcl(self, image):
        print(np.array(image).shape)
        depth = self.predict(image)
        print(depth.shape)
        # Step 2: Create an RGBD image from the RGB and depth image
        depth_o3d = o3d.geometry.Image(depth)
        image_o3d = o3d.geometry.Image(np.array(image))
        rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
            image_o3d, depth_o3d, convert_rgb_to_intensity=False
        )
        # Step 3: Create a PointCloud from the RGBD image
        pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
            rgbd_image,
            o3d.camera.PinholeCameraIntrinsic(
                o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
            ),
        )
        # Step 4: Convert PointCloud data to a NumPy array
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors)
        print(points.shape, colors.shape)
        return points, colors

    def generate_fig(self, image):
        points, colors = self.generate_pcl(image)
        data = {
            "x": points[:, 0],
            "y": points[:, 1],
            "z": points[:, 2],
            "red": colors[:, 0],
            "green": colors[:, 1],
            "blue": colors[:, 2],
        }
        df = pd.DataFrame(data)
        size = np.zeros(len(df))
        size[:] = 0.01
        # Step 6: Create a 3D scatter plot using Plotly Express
        fig = px.scatter_3d(df, x="x", y="y", z="z", color="red", size=size)
        return fig

    def generate_fig2(self, image):
        points, colors = self.generate_pcl(image)
        # Create a 3D scatter plot with matplotlib: x, y and z are passed
        # separately, and the marker size goes through the `s` argument.
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=0.01, c=colors, marker="o")
        return fig

    def generate_obj_rgb(self, image, n_samples, cube_size, max_depth, min_depth):
        # Step 1: Create a point cloud
        depth = self.predict(image)
        image = np.array(image)
        depth = map_image_range(depth, min_depth, max_depth)
        point_cloud, color_array = PCL_rgb(image, depth)
        idxs = np.random.choice(len(point_cloud), int(n_samples))
        point_cloud = point_cloud[idxs]
        color_array = color_array[idxs]
        # Create a mesh to hold the colored cubes
        mesh = o3d.geometry.TriangleMesh()
        # Create cubes and add them to the mesh
        for point, color in zip(point_cloud, color_array):
            cube = o3d.geometry.TriangleMesh.create_box(
                width=cube_size, height=cube_size, depth=cube_size
            )
            cube.translate(-point)
            cube.paint_uniform_color(color)
            mesh += cube
        # Save the mesh to an .obj file
        output_file = "./cloud.obj"
        o3d.io.write_triangle_mesh(output_file, mesh)
        return output_file

    def generate_obj_masks(self, image, n_samples, masks, cube_size):
        # Generate a point cloud
        point_cloud, color_array = self.generate_pcl(image)
        print(point_cloud.shape)
        mesh = o3d.geometry.TriangleMesh()
        # Create cubes and add them to the mesh
        cs = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
        for c, (mask, _) in zip(cs, masks):
            mask = mask.ravel()
            point_cloud_subset, color_array_subset = (
                point_cloud[mask],
                color_array[mask],
            )
            idxs = np.random.choice(len(point_cloud_subset), int(n_samples))
            point_cloud_subset = point_cloud_subset[idxs]
            for point in point_cloud_subset:
                cube = o3d.geometry.TriangleMesh.create_box(
                    width=cube_size, height=cube_size, depth=cube_size
                )
                cube.translate(-point)
                cube.paint_uniform_color(c)
                mesh += cube
        # Save the mesh to an .obj file
        output_file = "./cloud.obj"
        o3d.io.write_triangle_mesh(output_file, mesh)
        return output_file

    def generate_obj_masks2(
        self, image, masks, cube_size, n_samples, min_depth, max_depth
    ):
        # Generate a point cloud
        depth = self.predict(image)
        depth = map_image_range(depth, min_depth, max_depth)
        image = np.array(image)
        mesh = o3d.geometry.TriangleMesh()
        # Create cubes and add them to the mesh
        print(len(masks))
        cs = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
        for c, (mask, _) in zip(cs, masks):
            points, _ = PCL(mask, depth)
            idxs = np.random.choice(len(points), int(n_samples))
            points = points[idxs]
            points = remove_outliers(points)
            for point in points:
                cube = o3d.geometry.TriangleMesh.create_box(
                    width=cube_size, height=cube_size, depth=cube_size
                )
                cube.translate(-point)
                cube.paint_uniform_color(c)
                mesh += cube
        # Save the mesh to an .obj file
        output_file = "./cloud.obj"
        o3d.io.write_triangle_mesh(output_file, mesh)
        return output_file


import numpy as np
from typing import Optional, Tuple


class CustomSamPredictor(SamPredictor):
    def __init__(
        self,
        sam_model,
    ) -> None:
        super().__init__(sam_model)

    def encode_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            image = image[..., ::-1]

        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
            None, :, :, :
        ]
        self.set_torch_image(input_image_torch, image.shape[:2])
        return self.get_image_embedding()

    def decode_and_predict(
        self,
        embedding: torch.Tensor,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Decodes the provided image embedding and makes mask predictions based on prompts.

        Arguments:
          embedding (torch.Tensor): The image embedding to decode.
          ... (other arguments from the predict function)

        Returns:
          (np.ndarray): The output masks in CxHxW format.
          (np.ndarray): An array of quality predictions for each mask.
          (np.ndarray): Low resolution mask logits for subsequent iterations.
        """
        self.features = embedding
        self.is_image_set = True
        return self.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            mask_input=mask_input,
            multimask_output=multimask_output,
            return_logits=return_logits,
        )

    def dummy_set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        # The following line is commented out to avoid encoding on cpu
        # self.features = self.model.image_encoder(input_image)
        self.is_image_set = True

    def dummy_set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            image = image[..., ::-1]

        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
            None, :, :, :
        ]

        self.dummy_set_torch_image(input_image_torch, image.shape[:2])


class SegmentPredictor:
    def __init__(self, device=None):
        MODEL_TYPE = "vit_h"
        checkpoint = "sam_vit_h_4b8939.pth"
        sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint)
        # Select device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        sam.to(device=self.device)
        self.mask_generator = SamAutomaticMaskGenerator(sam)
        self.conditioned_pred = CustomSamPredictor(sam)

    def encode(self, image):
        image = np.array(image)
        return self.conditioned_pred.encode_image(image)

    def dummy_encode(self, image):
        image = np.array(image)
        self.conditioned_pred.dummy_set_image(image)

    def cond_pred(self, embedding, pts, lbls):
        lbls = np.array(lbls)
        pts = np.array(pts)
        masks, _, _ = self.conditioned_pred.decode_and_predict(
            embedding, point_coords=pts, point_labels=lbls, multimask_output=True
        )
        idxs = np.argsort(-masks.sum(axis=(1, 2)))
        sam_masks = []
        for n, i in enumerate(idxs):
            sam_masks.append((masks[i], str(n)))
        return sam_masks

    def segment_everything(self, image):
        image = np.array(image)
        sam_result = self.mask_generator.generate(image)
        sam_masks = []
        for i, mask in enumerate(sam_result):
            sam_masks.append((mask["segmentation"], str(i)))
        return sam_masks
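
The docstrings above describe the encode/decode split that makes the app interactive: encode_image runs the heavy ViT-H image encoder once, while decode_and_predict injects a cached embedding and runs only the lightweight prompt decoder, and dummy_set_image prepares a CPU predictor without re-encoding. A minimal sketch of how these pieces compose outside the UI (the image path is hypothetical, and the sam_vit_h_4b8939.pth checkpoint must be present):

# Illustrative only: encode once, then decode point prompts against the cached embedding.
import numpy as np
from PIL import Image
from inference import SegmentPredictor

sam_gpu = SegmentPredictor()              # runs the image encoder (GPU if available)
sam_cpu = SegmentPredictor(device="cpu")  # only ever runs the prompt decoder

image = Image.open("example.jpg").convert("RGB")

embedding = sam_gpu.encode(image).cpu()   # one expensive ViT-H forward pass
sam_cpu.dummy_encode(image)               # sets sizes/transforms on CPU without re-encoding

points = np.array([[320, 240]])           # pixel coordinates of a click
labels = np.array([1])                    # 1 = foreground point, 0 = background point
masks = sam_cpu.cond_pred(pts=points, lbls=labels, embedding=embedding)
print(len(masks), masks[0][0].shape)      # (mask, label-string) tuples, largest mask first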
packages.txt
ADDED
@@ -0,0 +1 @@
libgl1-mesa-glx
requirements.txt
ADDED
@@ -0,0 +1,12 @@
gradio
huggingface_hub
segment-anything
supervision
torch
torchvision
opencv-python
transformers
open3d
plotly
pandas
numpy
sam_vit_b_01ec64.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
size 375042383
sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
size 2564550879
tests.py
ADDED
File without changes
utils.py
ADDED
@@ -0,0 +1,231 @@
import numpy as np
import open3d as o3d
import plotly.express as px
import pandas as pd
from inference import DepthPredictor
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def create_3d_obj(rgb_image, depth_image, depth=10, path="./image.gltf"):
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )
    w = int(depth_image.shape[1])
    h = int(depth_image.shape[0])

    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic()
    camera_intrinsic.set_intrinsics(w, h, 500, 500, w / 2, h / 2)

    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, camera_intrinsic)

    print("normals")
    pcd.normals = o3d.utility.Vector3dVector(
        np.zeros((1, 3))
    )  # invalidate existing normals
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30)
    )
    pcd.orient_normals_towards_camera_location(
        camera_location=np.array([0.0, 0.0, 1000.0])
    )
    pcd.transform([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
    pcd.transform([[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])

    print("run Poisson surface reconstruction")
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
        mesh_raw, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
            pcd, depth=depth, width=0, scale=1.1, linear_fit=True
        )

    voxel_size = max(mesh_raw.get_max_bound() - mesh_raw.get_min_bound()) / 256
    print(f"voxel_size = {voxel_size:e}")
    mesh = mesh_raw.simplify_vertex_clustering(
        voxel_size=voxel_size,
        contraction=o3d.geometry.SimplificationContraction.Average,
    )

    # vertices_to_remove = densities < np.quantile(densities, 0.001)
    # mesh.remove_vertices_by_mask(vertices_to_remove)
    bbox = pcd.get_axis_aligned_bounding_box()
    mesh_crop = mesh.crop(bbox)
    gltf_path = path
    o3d.io.write_triangle_mesh(gltf_path, mesh_crop, write_triangle_uvs=True)
    return gltf_path


def create_3d_pc(rgb_image, depth_image, depth=10):
    depth_image = depth_image.astype(np.float32)  # Convert depth map to float32
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )

    w = int(depth_image.shape[1])
    h = int(depth_image.shape[0])

    # Specify camera intrinsic parameters (modify based on actual camera)
    fx = 500
    fy = 500
    cx = w / 2
    cy = h / 2

    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic(w, h, fx, fy, cx, cy)

    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, camera_intrinsic)

    print("Estimating normals...")
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30)
    )
    pcd.orient_normals_towards_camera_location(
        camera_location=np.array([0.0, 0.0, 1000.0])
    )

    # Save the point cloud as a PCD file
    filename = "pc.pcd"
    o3d.io.write_point_cloud(filename, pcd)

    return filename  # Return the file path where the point cloud is saved


def point_cloud(rgb_image):
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(rgb_image)
    # Step 2: Create an RGBD image from the RGB image and the predicted depth map
    depth_o3d = o3d.geometry.Image(depth_result)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )
    # Step 3: Create a PointCloud from the RGBD image
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image,
        o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
        ),
    )
    # Step 4: Convert PointCloud data to a NumPy array
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    # Step 5: Create a DataFrame from the NumPy arrays
    data = {
        "x": points[:, 0],
        "y": points[:, 1],
        "z": points[:, 2],
        "red": colors[:, 0],
        "green": colors[:, 1],
        "blue": colors[:, 2],
    }
    df = pd.DataFrame(data)
    size = np.zeros(len(df))
    size[:] = 0.01
    # Step 6: Create a 3D scatter plot using Plotly Express
    fig = px.scatter_3d(df, x="x", y="y", z="z", color="red", size=size)

    return fig


def array_PCL(rgb_image, depth_image):
    FX_RGB = 5.1885790117450188e02
    FY_RGB = 5.1946961112127485e02
    CX_RGB = 3.2558244941119034e0
    CY_RGB = 2.5373616633400465e02
    FX_DEPTH = FX_RGB
    FY_DEPTH = FY_RGB
    CX_DEPTH = CX_RGB
    CY_DEPTH = CY_RGB
    height = depth_image.shape[0]
    width = depth_image.shape[1]
    # compute indices:
    jj = np.tile(range(width), height)
    ii = np.repeat(range(height), width)

    # Compute constants:
    xx = (jj - CX_DEPTH) / FX_DEPTH
    yy = (ii - CY_DEPTH) / FY_DEPTH

    # transform depth image to vector of z:
    length = height * width
    z = depth_image.reshape(length)

    # compute point cloud
    pcd = np.dstack((xx * z, yy * z, z)).reshape((length, 3))
    # cam_RGB = np.apply_along_axis(np.linalg.inv(R).dot, 1, pcd) - np.linalg.inv(R).dot(T)
    xx_rgb = (
        ((rgb_image[:, 0] * FX_RGB) / rgb_image[:, 2] + CX_RGB + width / 2)
        .astype(int)
        .clip(0, width - 1)
    )
    yy_rgb = (
        ((rgb_image[:, 1] * FY_RGB) / rgb_image[:, 2] + CY_RGB)
        .astype(int)
        .clip(0, height - 1)
    )
    # colors = rgb_image[yy_rgb, xx_rgb]/255
    return pcd  # , colors


def generate_PCL(image):
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(image)
    image = np.array(image)
    pcd = array_PCL(image, depth_result)
    fig = px.scatter_3d(x=pcd[:, 0], y=pcd[:, 1], z=pcd[:, 2], size_max=0.01)
    return fig


def plot_PCL(rgb_image, depth_image):
    pcd, colors = array_PCL(rgb_image, depth_image)
    fig = px.scatter_3d(
        x=pcd[:, 0], y=pcd[:, 1], z=pcd[:, 2], color=colors, size_max=0.1
    )
    return fig


def PCL3(image):
    depth_predictor = DepthPredictor()
    depth_result = depth_predictor.predict(image)
    image = np.array(image)
    # Step 2: Create an RGBD image from the RGB and depth image
    depth_o3d = o3d.geometry.Image(depth_result)
    image_o3d = o3d.geometry.Image(image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False
    )
    # Step 3: Create a PointCloud from the RGBD image
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        rgbd_image,
        o3d.camera.PinholeCameraIntrinsic(
            o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault
        ),
    )
    # Step 4: Convert PointCloud data to a NumPy array
    vis = o3d.visualization.Visualizer()
    vis.add_geometry(pcd)
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    sizes = np.zeros(colors.shape)
    sizes[:] = 0.01
    colors = [tuple(c) for c in colors]
    fig = plt.figure()
    # ax = fig.add_subplot(111, projection='3d')
    ax = Axes3D(fig)
    print("plotting...")
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=colors, s=0.01)
    print("Plot successful")
    # data = {'x': points[:, 0], 'y': points[:, 1], 'z': points[:, 2], 'sizes': sizes[:, 0]}
    # df = pd.DataFrame(data)
    # Step 6: Create a 3D scatter plot using Plotly Express
    # fig = px.scatter_3d(df, x='x', y='y', z='z', color=colors, size="sizes")

    return fig
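
The helpers above (array_PCL, create_3d_obj, create_3d_pc) all rest on the same pinhole back-projection: a pixel (u, v) with depth z maps to X = (u - cx) * z / fx, Y = (v - cy) * z / fy, Z = z. The intrinsics used here are assumed rather than calibrated (fx = fy = 500 with the principal point at the image centre in create_3d_obj/create_3d_pc). A small worked sketch, with cx and cy chosen for a hypothetical 640x480 image:

# Illustrative only: back-project one pixel with assumed intrinsics.
import numpy as np

def backproject_pixel(u, v, z, fx=500.0, fy=500.0, cx=320.0, cy=240.0):
    # Map a pixel (u, v) with depth z to a 3-D point in camera coordinates.
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    return np.array([x, y, z])

# A pixel 100 px right of the centre at 2 m depth lands 0.4 m to the right in 3-D.
print(backproject_pixel(420, 240, 2.0))  # -> [0.4 0.  2. ]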