ohjho committed b240372 (parent: b369bda)
tested BTS model and added to the app
BTS.py
CHANGED
@@ -419,6 +419,11 @@ class BtsController:
         depth_map = np.asarray(cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB), np.uint8)
         return depth_map
 
+    @staticmethod
+    def depth_map_to_grayimg(depth_map):
+        depth_map = np.asarray(np.squeeze((255 - torch.clamp_max(depth_map * 4, 250)).byte().numpy()), np.uint8)
+        return depth_map
+
     @staticmethod
     def normalize_img(image):
         transformation = A.Compose([
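As a side note on the new method: a minimal sketch, not part of the commit, of what depth_map_to_grayimg computes. The random tensor stands in for a real BtsController prediction, and the (1, 1, h, w) shape is an assumption about the raw output:

import numpy as np, torch

depth = torch.rand(1, 1, 4, 4) * 80   # stand-in depth prediction, assumed (1, 1, h, w) float
# scale by 4, cap at 250, invert: nearer pixels come out brighter, far pixels floor at 5
gray = np.asarray(np.squeeze((255 - torch.clamp_max(depth * 4, 250)).byte().numpy()), np.uint8)
assert gray.shape == (4, 4) and gray.dtype == np.uint8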
BTS_infer.py
CHANGED
@@ -1,22 +1,25 @@
 import BTS, cv2, torch, gdown, os, zipfile
 import numpy as np
+from PIL import Image
 
-def download_model_weight(model_dir,
+def download_model_weight(model_dir,
+                          file_key_dict = {'bts_latest':"1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"}):
     if not os.path.isdir(model_dir):
         print(f'--- making model directory: {model_dir}')
         os.makedirs(model_dir)
+    fname = list(file_key_dict.keys())[0]
+    key = file_key_dict[fname]
     url = f'https://drive.google.com/uc?id={key}&export=download'
-    tmp_zip_fp = os.path.join(model_dir,
+    tmp_zip_fp = os.path.join(model_dir, fname)
 
     print(f'--- downloading model weights from {url}')
     gdown.download(url, tmp_zip_fp, quiet = True)
 
-    with zipfile.ZipFile(tmp_zip_fp, "r") as zip_ref:
-        for file in zip_ref.namelist():
-            zip_ref.extract(file, model_dir)
-    os.remove(tmp_zip_fp)
-    print("--- downloaded model weights done!", flush=True)
+    # with zipfile.ZipFile(tmp_zip_fp, "r") as zip_ref:
+    #     for file in zip_ref.namelist():
+    #         zip_ref.extract(file, model_dir)
+    # os.remove(tmp_zip_fp)
+    print(f"--- downloaded model weights to {tmp_zip_fp}", flush=True)
 
 def get_model(model_path = './models/bts_latest'):
     if not os.path.isfile(model_path):
@@ -26,11 +29,58 @@ def get_model(model_path = './models/bts_latest'):
     model.eval()
     return model
 
-def inference(img_array_rgb, model_obj):
+def im_max_long_edge(im_np_array, size = 1080, return_pil_im = False,
+                     resample_algo = Image.LANCZOS, debug = False):
+    ''' Return an image whose long edge is no longer than the given size
+    Args:
+        resample_algo: default to LANCZOS b/c it gives best downscaling quality (per https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table)
+    '''
+    org_h, org_w, _ = im_np_array.shape
+    out_im = None
+    if debug:
+        print(f'im_max_long_edge: seeing input w,h of {(org_w, org_h)}')
+
+    if max(org_h, org_w) <= size:
+        out_im = im_np_array
+        if debug:
+            print(f'im_max_long_edge: image dim is smaller than max {size}. no resizing required.')
+    else:
+        wh_ratio = org_w / org_h
+        if org_h > org_w:
+            # fix h to size
+            h = size
+            w = h * wh_ratio
+        else:
+            # fix w to size
+            w = size
+            h = w / wh_ratio
+        w = int(w)
+        h = int(h)
+        pil_im = Image.fromarray(im_np_array).resize((w,h), resample = resample_algo)
+        out_im = np.array(pil_im)
+
+        if debug:
+            print(f'im_max_long_edge: resizing image to w,h of {(w,h)}')
+    return Image.fromarray(out_im) if return_pil_im else out_im
+
+def format_depth_map(depth_map, debug = True):
+    dmax = depth_map.max()
+    dmin = depth_map.min()
+    print(f'depth map origin min-max: ({dmin}, {dmax})')
+    # formatted = ((depth_map /dmax)* 255).astype('uint8')
+
+    # min-max normalization
+    formatted = (depth_map - depth_map.min())/(depth_map.max()-depth_map.min())
+    return (formatted * 255).astype('uint8')
+
+def inference(img_array_rgb, model_obj, as_pil = False):
+    h, w, _ = img_array_rgb.shape
+    img_array_rgb = im_max_long_edge(img_array_rgb, return_pil_im = False, size = 720)
     prediction = model_obj.predict(img_array_rgb, is_channels_first = False, normalize = True)
-    visual_depth_map = model_obj.
+    visual_depth_map = model_obj.depth_map_to_grayimg(prediction)
+    visual_depth_map = format_depth_map(visual_depth_map)
+    visual_depth_map = Image.fromarray(visual_depth_map).resize((w,h), resample = Image.LANCZOS)
+    return visual_depth_map if as_pil else np.array(visual_depth_map)
 
     # prediction = torch.nn.functional.interpolate(
     # prediction.unsqueeze(1),
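A quick numeric check, again not part of the commit, of the resize math in the new im_max_long_edge: a 1000x2000 (h, w) input with size = 720 caps the long edge (the width) at 720 and scales the height by the same w/h ratio, giving 720 / 2 = 360. The zero array is a stand-in for a real photo:

import numpy as np
from BTS_infer import im_max_long_edge

im = np.zeros((1000, 2000, 3), dtype = np.uint8)   # (h, w, c) stand-in image
out = im_max_long_edge(im, size = 720)
assert out.shape == (360, 720, 3)                  # long edge capped, aspect ratio kept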
DPT.py
CHANGED
@@ -28,7 +28,7 @@ def load_model(model_type = 'DPT_Large'):
         'midas': midas, 'device': device, 'transform': transform
     }
 
-def inference(img_array_rgb, model_obj):
+def inference(img_array_rgb, model_obj, as_pil = False):
     '''run DPT model and returns a PIL image'''
     # img = cv2.imread(img.name)
     # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -50,7 +50,7 @@ def inference(img_array_rgb, model_obj):
     output = prediction.cpu().numpy()
     formatted = (output * 255 / np.max(output)).astype('uint8')
     img = Image.fromarray(formatted)
-    return img
+    return img if as_pil else formatted
 
     # inputs = gr.inputs.Image(type='file', label="Original Image")
     # outputs = gr.outputs.Image(type="pil",label="Output Image")
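The only behavioural change to DPT.py is the as_pil flag, which preserves the old PIL return while letting app.py work with a numpy depth map. A hedged usage sketch, where 'street.jpg' is a placeholder path:

import numpy as np
from PIL import Image
import DPT

model_obj = DPT.load_model()
rgb = np.array(Image.open('street.jpg').convert('RGB'))   # placeholder input image
depth_arr = DPT.inference(rgb, model_obj)                 # new default: uint8 numpy array
depth_pil = DPT.inference(rgb, model_obj, as_pil = True)  # pre-commit behaviour: PIL Image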
app.py
CHANGED
@@ -26,93 +26,134 @@ def get_image(st_asset = st.sidebar, as_np_arr = False, extension_list = ['jpg',
     im = np.array(im)
     return im
 
-def show_miro_logo(use_column_width = False, width = 100, st_asset= st.sidebar):
+def show_miro_logo(use_column_width = False, width = 100, st_asset= st.sidebar, str_color = 'white'):
+    logo_url = f'https://miro-ps-bucket-copy.s3.us-west-2.amazonaws.com/storage/jho/web_asset/logo/miro_logo_{str_color}.png'
+    st_asset.image(logo_url, use_column_width = use_column_width, channels = 'BGR', output_format = 'PNG', width = width)
 
-def
+def im_apply_mask(im_rgb_array, mask_array, get_pil_im = False, bg_rgb_tup = None,
+                  bg_blur_radius = None, bg_greyscale = False, mask_gblur_radius = 0):
     '''
+    return either a np array with 4 channels or PIL Image with alpha
+    ref: https://stackoverflow.com/questions/47723154/how-to-use-pil-paste-with-mask
+    ref: https://stackoverflow.com/questions/62273005/compositing-images-by-blurred-mask-in-numpy
+    ref: https://stackoverflow.com/questions/62968174/for-pil-imagefilter-gaussianblur-how-what-kernel-is-used-and-does-the-radius-par
+
     Args:
+        bg_rgb_tup: if given, return a 3-channel image with color background instead of transparent
+        bg_blur_radius: if given, return a 3-channel image with GaussianBlur applied to the background
     '''
+    h, w, c = im_rgb_array.shape
+    m_h, m_w = mask_array.shape
+
+    if not all([h == m_h, w == m_w]):
+        raise ValueError(f'im_apply_mask: mask_array size {(m_h, m_w)} must match im_rgb_array {(h, w)}')
+
+    im = Image.fromarray(im_rgb_array)
+
+    # convert bitwise mask from np to pillow
+    # ref: https://note.nkmk.me/en/python-pillow-paste/
+    pil_mask = Image.fromarray(np.uint8(255* mask_array))
+    pil_mask = pil_mask.filter(
+        ImageFilter.GaussianBlur(radius = mask_gblur_radius)
+    ) if mask_gblur_radius > 0 else pil_mask
+
+    if bg_rgb_tup:
+        bg_im = np.zeros([h,w,3], dtype = np.uint8) # black
+        bg_im[:,:] = bg_rgb_tup # apply color
+
+        # old method using just np but doesn't support blurred mask
+        # idx = (mask_array != 0)
+        # bg_im[idx] = im_rgb_array[idx]
+
+        bg_im = Image.fromarray(bg_im)
+        bg_im.paste(im, mask = pil_mask)
+        im = bg_im
+    elif bg_blur_radius:
+        bg_im = im.copy().filter(
+            ImageFilter.GaussianBlur(radius = bg_blur_radius)
+        )
+        bg_im.paste(im, mask = pil_mask)
+        im = bg_im
+    elif bg_greyscale:
+        bg_im = ImageOps.grayscale(Image.fromarray(im_rgb_array))
+        bg_im = np.array(bg_im)
+        bg_im = np.stack((bg_im,)*3, axis = -1) # greyscale 1-channel to 3-channel
+
+        bg_im = Image.fromarray(bg_im)
+        bg_im.paste(im, mask = pil_mask)
+        im = bg_im
     else:
-        thickness = width
-    )
-    im_array = bbv.add_label(
-        im_array, label = caption,
-        bbox = [x0,y0,x1,y1],
-        text_bg_color = ImageColor.getrgb(color)
-    )if caption else im_array
-    return Image.fromarray(im_array)
+        im.putalpha(pil_mask)
+
+    return im if get_pil_im else np.array(im)
 
 ### Streamlit App ###
+# @st.experimental_memo
 @st.cache(allow_output_mutation = True)
 def get_model_zoo():
     model_zoo = {
         'DPT': {'infer_func': DPT.inference,'model': DPT.load_model()},
+        'BTS': {'infer_func': BTS_infer.inference,'model': BTS_infer.get_model()}
     }
     return model_zoo
 
-@st.
+# @st.experimental_memo(suppress_st_warning=True)
+@st.cache(suppress_st_warning=True,
+          hash_funcs={st.delta_generator.DeltaGenerator: lambda _:None})
+def mono_depth(pil_im, model_name, _st_asset = None):
     s_time = time.time()
     model_zoo = get_model_zoo()
     infer_func = model_zoo[model_name]['infer_func']
     model_obj = model_zoo[model_name]['model']
     depth_im = infer_func(img_array_rgb = np.array(pil_im),
                 model_obj = model_obj)
+    if _st_asset:
+        with _st_asset:
+            st.info(f'''
+            model name: {model_name}\n
+            inference time: `{round(time.time()-s_time,2)}` seconds\n
+            depth image shape: {np.array(depth_im).shape}\n
+            depth image type: {type(depth_im)}\n
+            depth map min-max: {depth_im.min()}, {depth_im.max()}
+            ''')
     return depth_im
 
-def Main():
-    st.set_page_config(
+def Main(): # streamlit version 1.9.2
+    st.set_page_config(
+        layout = 'wide',
+        page_title = 'Monocular Depth',
+        page_icon = 'https://miro.io/favicon-32x32.png',
+        initial_sidebar_state = 'collapsed'
+    )
     l_col, r_col = st.columns(2)
-    show_miro_logo(st_asset = l_col)
+    show_miro_logo(st_asset = l_col, str_color = 'purple', width = 200)
     with l_col.expander('Monocular Depth: CNN vs Transformers'):
         st.info(f'''
-        Comparsion of two
+        Comparsion of two [SoTA](https://paperswithcode.com/sota/monocular-depth-estimation-on-nyu-depth-v2) models:
+        [BTS (CNN), 2019](https://github.com/ErenBalatkan/Bts-PyTorch)
+        and [DPT (Transformer), 2021](https://huggingface.co/Intel/dpt-large)
         ''')
     model_zoo = get_model_zoo()
     im = get_image(st_asset = r_col.expander('Input Image', expanded = True), extension_list = ['jpg','jpeg'])
     model_name = l_col.selectbox('Pick Model', options = list(model_zoo.keys()))
 
     if im:
-        d_im = mono_depth(pil_im = im, model_name=model_name
+        d_im = mono_depth(pil_im = im, model_name=model_name,
+                _st_asset = r_col.expander('inference info'))
 
         l_col, r_col = st.columns(2)
         l_col.image(im, caption = 'Input Image')
        r_col.image(d_im, caption = 'Depth Map')
+
+        with l_col.form('depth filter'):
+            min_d, max_d = st.slider('Depth Filter', value = (0,255),
+                        help = 'smaller value = further away from camera',
+                        min_value = 0, max_value = 255)
+            submitted = st.form_submit_button('filter depth')
+            if submitted:
+                depth_mask = ((d_im>= min_d) & (d_im<=max_d))
+                depth_filter_im = im_apply_mask(np.array(im),mask_array = depth_mask)
+                r_col.image(depth_filter_im, caption = 'Depth Filtered Image')
     else:
         st.warning(f'please provide an image :point_up:')
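The new depth-filter form keeps pixels whose depth-map value falls inside the slider range and passes the boolean mask to im_apply_mask, whose default path converts the mask to a PIL 'L' image and attaches it as an alpha channel. A minimal sketch of that default path outside Streamlit, using stand-in data throughout; the thresholds are placeholder slider values:

import numpy as np
from PIL import Image

h, w = 4, 6
im = Image.fromarray(np.random.randint(0, 255, (h, w, 3), dtype = np.uint8))  # stand-in photo
d_im = np.random.randint(0, 255, (h, w), dtype = np.uint8)                    # stand-in depth map

min_d, max_d = 100, 255                                  # placeholder slider values
depth_mask = (d_im >= min_d) & (d_im <= max_d)           # boolean keep-mask, True = keep pixel
pil_mask = Image.fromarray(np.uint8(255 * depth_mask))   # same np-to-PIL step im_apply_mask uses
im.putalpha(pil_mask)                                    # outside the depth band becomes transparent
assert im.mode == 'RGBA'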