ohjho committed
Commit b240372
1 Parent(s): b369bda

tested BTS model and added to the app

Files changed (4)
  1. BTS.py +5 -0
  2. BTS_infer.py +62 -12
  3. DPT.py +2 -2
  4. app.py +93 -52
BTS.py CHANGED
@@ -419,6 +419,11 @@ class BtsController:
         depth_map = np.asarray(cv2.cvtColor(depth_map, cv2.COLOR_GRAY2RGB), np.uint8)
         return depth_map
 
+    @staticmethod
+    def depth_map_to_grayimg(depth_map):
+        depth_map = np.asarray(np.squeeze((255 - torch.clamp_max(depth_map * 4, 250)).byte().numpy()), np.uint8)
+        return depth_map
+
     @staticmethod
     def normalize_img(image):
         transformation = A.Compose([
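
The added depth_map_to_grayimg turns the raw BTS depth tensor into an inverted 8-bit grayscale image: depth is scaled by 4, clamped at 250, and subtracted from 255, so near pixels come out bright and far pixels dark. A minimal sketch of the arithmetic on a dummy tensor; the (1, 1, H, W) shape is an assumption, since this diff does not show what BtsController.predict actually returns:

import numpy as np
import torch

depth = torch.rand(1, 1, 4, 4) * 80  # fake metric depth in meters (shape is an assumption)

# same arithmetic as the added method: scale, clamp, invert, squeeze to uint8
gray = np.asarray(
    np.squeeze((255 - torch.clamp_max(depth * 4, 250)).byte().numpy()),
    np.uint8,
)
print(gray.shape, gray.dtype)  # (4, 4) uint8; values in [5, 255], anything past 62.5 m saturates at 5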
BTS_infer.py CHANGED
@@ -1,22 +1,25 @@
 import BTS, cv2, torch, gdown, os, zipfile
 import numpy as np
+from PIL import Image
 
-def download_model_weight(model_dir, key = "1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"):
+def download_model_weight(model_dir,
+    file_key_dict = {'bts_latest': "1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"}):
     if not os.path.isdir(model_dir):
         print(f'--- making model directory: {model_dir}')
         os.makedirs(model_dir)
+    fname = list(file_key_dict.keys())[0]
+    key = file_key_dict[fname]
     url = f'https://drive.google.com/uc?id={key}&export=download'
-    tmp_zip_fp = os.path.join(model_dir, 'tmp.zip')
+    tmp_zip_fp = os.path.join(model_dir, fname)
 
     print(f'--- downloading model weights from {url}')
     gdown.download(url, tmp_zip_fp, quiet = True)
 
-    with zipfile.ZipFile(tmp_zip_fp, "r") as zip_ref:
-        for file in zip_ref.namelist():
-            zip_ref.extract(file, model_dir)
-
-    os.remove(tmp_zip_fp)
-    print("--- downloaded model weights done!", flush=True)
+    # with zipfile.ZipFile(tmp_zip_fp, "r") as zip_ref:
+    #     for file in zip_ref.namelist():
+    #         zip_ref.extract(file, model_dir)
+    # os.remove(tmp_zip_fp)
+    print(f"--- downloaded model weights to {tmp_zip_fp}", flush=True)
 
 def get_model(model_path = './models/bts_latest'):
     if not os.path.isfile(model_path):
@@ -26,11 +29,58 @@ def get_model(model_path = './models/bts_latest'):
     model.eval()
     return model
 
-def inference(img_array_rgb, model_obj):
-    # TODO: add resize max 1080 and multiple of 32
+def im_max_long_edge(im_np_array, size = 1080, return_pil_im = False,
+    resample_algo = Image.LANCZOS, debug = False):
+    ''' Return an image whose long edge is no longer than the given size
+    Args:
+        resample_algo: defaults to LANCZOS because it gives the best downscaling quality (per https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table)
+    '''
+    org_h, org_w, _ = im_np_array.shape
+    out_im = None
+    if debug:
+        print(f'im_max_long_edge: seeing input w,h of {(org_w, org_h)}')
+
+    if max(org_h, org_w) <= size:
+        out_im = im_np_array
+        if debug:
+            print(f'im_max_long_edge: image dim is smaller than max {size}. no resizing required.')
+    else:
+        wh_ratio = org_w / org_h
+        if org_h > org_w:
+            # fix h to size
+            h = size
+            w = h * wh_ratio
+        else:
+            # fix w to size
+            w = size
+            h = w / wh_ratio
+        w = int(w)
+        h = int(h)
+        pil_im = Image.fromarray(im_np_array).resize((w,h), resample = resample_algo)
+        out_im = np.array(pil_im)
+
+        if debug:
+            print(f'im_max_long_edge: resizing image to w,h of {(w,h)}')
+    return Image.fromarray(out_im) if return_pil_im else out_im
+
+def format_depth_map(depth_map, debug = True):
+    dmax = depth_map.max()
+    dmin = depth_map.min()
+    print(f'depth map origin min-max: ({dmin}, {dmax})')
+    # formatted = ((depth_map / dmax) * 255).astype('uint8')
+
+    # min-max normalization
+    formatted = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
+    return (formatted * 255).astype('uint8')
+
+def inference(img_array_rgb, model_obj, as_pil = False):
+    h, w, _ = img_array_rgb.shape
+    img_array_rgb = im_max_long_edge(img_array_rgb, return_pil_im = False, size = 720)
     prediction = model_obj.predict(img_array_rgb, is_channels_first = False, normalize = True)
-    visual_depth_map = model_obj.depth_map_to_rgbimg(prediction)
-    return visual_depth_map
+    visual_depth_map = model_obj.depth_map_to_grayimg(prediction)
+    visual_depth_map = format_depth_map(visual_depth_map)
+    visual_depth_map = Image.fromarray(visual_depth_map).resize((w,h), resample = Image.LANCZOS)
+    return visual_depth_map if as_pil else np.array(visual_depth_map)
 
     # prediction = torch.nn.functional.interpolate(
     #     prediction.unsqueeze(1),
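
The rewritten inference now downscales the input so its long edge is at most 720 px, runs the model, min-max normalizes the grayscale map to [0, 255] via (d - d.min()) / (d.max() - d.min()) * 255, and resizes the result back to the original width and height with LANCZOS. A quick sketch of the resize helper's behavior, assuming BTS_infer and its dependencies are importable; the shapes below are illustrative, not from the commit:

import numpy as np
from BTS_infer import im_max_long_edge  # assumes this repo is on the path

im = np.zeros((2000, 1500, 3), dtype=np.uint8)  # dummy image, h=2000, w=1500

# long edge (h) is fixed to 720; the short edge keeps the 0.75 aspect ratio
out = im_max_long_edge(im, size=720)
print(out.shape)  # (720, 540, 3)

# images already within the limit pass through untouched
small = np.zeros((480, 640, 3), dtype=np.uint8)
print(im_max_long_edge(small, size=720).shape)  # (480, 640, 3)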
DPT.py CHANGED
@@ -28,7 +28,7 @@ def load_model(model_type = 'DPT_Large'):
         'midas': midas, 'device': device, 'transform': transform
     }
 
-def inference(img_array_rgb, model_obj):
+def inference(img_array_rgb, model_obj, as_pil = False):
     '''run DPT model and returns a PIL image'''
     # img = cv2.imread(img.name)
     # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -50,7 +50,7 @@ def inference(img_array_rgb, model_obj):
     output = prediction.cpu().numpy()
     formatted = (output * 255 / np.max(output)).astype('uint8')
     img = Image.fromarray(formatted)
-    return img
+    return img if as_pil else formatted
 
     # inputs = gr.inputs.Image(type='file', label="Original Image")
     # outputs = gr.outputs.Image(type="pil",label="Output Image")
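
With the new as_pil flag, DPT.inference and BTS_infer.inference both default to returning a uint8 NumPy array, which is what the depth filter added in app.py needs for boolean comparisons. A hedged sketch of the two call shapes; the variable names are illustrative, and load_model() is assumed to be able to fetch the MiDaS/DPT weights:

import numpy as np
import DPT  # assumes DPT.py from this repo is importable

model_obj = DPT.load_model()
img = np.zeros((384, 384, 3), dtype=np.uint8)  # stand-in RGB image

depth_np = DPT.inference(img_array_rgb=img, model_obj=model_obj)                # uint8 np array (new default)
depth_pil = DPT.inference(img_array_rgb=img, model_obj=model_obj, as_pil=True)  # PIL.Image, as before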
app.py CHANGED
@@ -26,93 +26,134 @@ def get_image(st_asset = st.sidebar, as_np_arr = False, extension_list = ['jpg',
     im = np.array(im)
     return im
 
-def show_miro_logo(use_column_width = False, width = 100, st_asset= st.sidebar):
-    logo_url = 'https://miro.medium.com/max/1400/0*qLL-32srlq6Y_iTm.png'
+def show_miro_logo(use_column_width = False, width = 100, st_asset= st.sidebar, str_color = 'white'):
+    logo_url = f'https://miro-ps-bucket-copy.s3.us-west-2.amazonaws.com/storage/jho/web_asset/logo/miro_logo_{str_color}.png'
     st_asset.image(logo_url, use_column_width = use_column_width, channels = 'BGR', output_format = 'PNG', width = width)
 
-def im_draw_bbox(pil_im, x0, y0, x1, y1, color = 'black', width = 3, caption = None,
-    bbv_label_only = False):
+def im_apply_mask(im_rgb_array, mask_array, get_pil_im = False, bg_rgb_tup = None,
+    bg_blur_radius = None, bg_greyscale = False, mask_gblur_radius = 0):
     '''
-    draw bounding box on the input image pil_im in-place
+    return either a np array with 4 channels or PIL Image with alpha
+    ref: https://stackoverflow.com/questions/47723154/how-to-use-pil-paste-with-mask
+    ref: https://stackoverflow.com/questions/62273005/compositing-images-by-blurred-mask-in-numpy
+    ref: https://stackoverflow.com/questions/62968174/for-pil-imagefilter-gaussianblur-how-what-kernel-is-used-and-does-the-radius-par
+
     Args:
-        color: color name as read by Pillow.ImageColor
-        use_bbv: use bbox_visualizer
+        bg_rgb_tup: if given, return a 3-channel image with color background instead of transparent
+        bg_blur_radius: if given, return a 3-channel image with GaussianBlur applied to the background
     '''
-    import bbox_visualizer as bbv
-    if any([type(i)== float for i in [x0,y0,x1,y1]]):
-        warnings.warn(f'im_draw_bbox: at least one of x0,y0,x1,y1 is of the type float and is converted to int.')
-        x0 = int(x0)
-        y0 = int(y0)
-        x1 = int(x1)
-        y1 = int(y1)
-
-    if bbv_label_only:
-        if caption:
-            im_array = bbv.draw_flag_with_label(np.array(pil_im),
-                label = caption,
-                bbox = [x0,y0,x1,y1],
-                line_color = ImageColor.getrgb(color),
-                text_bg_color = ImageColor.getrgb(color)
-                )
-        else:
-            raise ValueError(f'im_draw_bbox: bbv_label_only is True but caption is None')
+    h, w, c = im_rgb_array.shape
+    m_h, m_w = mask_array.shape
+
+    if not all([h == m_h, w == m_w]):
+        raise ValueError(f'im_apply_mask: mask_array size {(m_h, m_w)} must match im_rgb_array {(h, w)}')
+
+    im = Image.fromarray(im_rgb_array)
+
+    # convert bitwise mask from np to pillow
+    # ref: https://note.nkmk.me/en/python-pillow-paste/
+    pil_mask = Image.fromarray(np.uint8(255 * mask_array))
+    pil_mask = pil_mask.filter(
+        ImageFilter.GaussianBlur(radius = mask_gblur_radius)
+    ) if mask_gblur_radius > 0 else pil_mask
+
+    if bg_rgb_tup:
+        bg_im = np.zeros([h,w,3], dtype = np.uint8) # black
+        bg_im[:,:] = bg_rgb_tup # apply color
+
+        # old method using just np but doesn't support blurred mask
+        # idx = (mask_array != 0)
+        # bg_im[idx] = im_rgb_array[idx]
+
+        bg_im = Image.fromarray(bg_im)
+        bg_im.paste(im, mask = pil_mask)
+        im = bg_im
+    elif bg_blur_radius:
+        bg_im = im.copy().filter(
+            ImageFilter.GaussianBlur(radius = bg_blur_radius)
+        )
+        bg_im.paste(im, mask = pil_mask)
+        im = bg_im
+    elif bg_greyscale:
+        bg_im = ImageOps.grayscale(Image.fromarray(im_rgb_array))
+        bg_im = np.array(bg_im)
+        bg_im = np.stack((bg_im,)*3, axis = -1) # greyscale 1-channel to 3-channel
+
+        bg_im = Image.fromarray(bg_im)
+        bg_im.paste(im, mask = pil_mask)
+        im = bg_im
     else:
-        im_array = bbv.draw_rectangle(np.array(pil_im),
-            bbox = [x0, y0, x1, y1],
-            bbox_color = ImageColor.getrgb(color),
-            thickness = width
-            )
-        im_array = bbv.add_label(
-            im_array, label = caption,
-            bbox = [x0,y0,x1,y1],
-            text_bg_color = ImageColor.getrgb(color)
-            ) if caption else im_array
-    return Image.fromarray(im_array)
+        im.putalpha(pil_mask)
+
+    return im if get_pil_im else np.array(im)
 
 ### Streamlit App ###
+# @st.experimental_memo
 @st.cache(allow_output_mutation = True)
 def get_model_zoo():
     model_zoo = {
         'DPT': {'infer_func': DPT.inference,'model': DPT.load_model()},
-        # 'BTS': {'infer_func': BTS_infer.inference,'model': BTS_infer.get_model()}
+        'BTS': {'infer_func': BTS_infer.inference,'model': BTS_infer.get_model()}
     }
     return model_zoo
 
-@st.cache(suppress_st_warning=True)
-def mono_depth(pil_im, model_name):
+# @st.experimental_memo(suppress_st_warning=True)
+@st.cache(suppress_st_warning=True,
+    hash_funcs={st.delta_generator.DeltaGenerator: lambda _: None})
+def mono_depth(pil_im, model_name, _st_asset = None):
     s_time = time.time()
     model_zoo = get_model_zoo()
     infer_func = model_zoo[model_name]['infer_func']
     model_obj = model_zoo[model_name]['model']
     depth_im = infer_func(img_array_rgb = np.array(pil_im),
         model_obj = model_obj)
-    st.info(f'''
-        model name: {model_name}\n
-        inference time: `{round(time.time()-s_time,2)}` seconds\n
-        depth image shape: {np.array(depth_im).shape}\n
-        depth image type: {type(depth_im)}
-        ''')
+    if _st_asset:
+        with _st_asset:
+            st.info(f'''
+                model name: {model_name}\n
+                inference time: `{round(time.time()-s_time,2)}` seconds\n
+                depth image shape: {np.array(depth_im).shape}\n
+                depth image type: {type(depth_im)}\n
+                depth map min-max: {depth_im.min()}, {depth_im.max()}
+                ''')
     return depth_im
 
-def Main():
-    st.set_page_config(layout = 'wide')
+def Main(): # streamlit version 1.9.2
+    st.set_page_config(
+        layout = 'wide',
+        page_title = 'Monocular Depth',
+        page_icon = 'https://miro.io/favicon-32x32.png',
+        initial_sidebar_state = 'collapsed'
+    )
     l_col, r_col = st.columns(2)
-    show_miro_logo(st_asset = l_col)
+    show_miro_logo(st_asset = l_col, str_color = 'purple', width = 200)
     with l_col.expander('Monocular Depth: CNN vs Transformers'):
         st.info(f'''
-        Comparsion of two models: [BTS (CNN)](https://github.com/ErenBalatkan/Bts-PyTorch)
-        and [DPT (Transformer)](https://huggingface.co/Intel/dpt-large)
+        Comparison of two [SoTA](https://paperswithcode.com/sota/monocular-depth-estimation-on-nyu-depth-v2) models:
+        [BTS (CNN), 2019](https://github.com/ErenBalatkan/Bts-PyTorch)
+        and [DPT (Transformer), 2021](https://huggingface.co/Intel/dpt-large)
         ''')
     model_zoo = get_model_zoo()
     im = get_image(st_asset = r_col.expander('Input Image', expanded = True), extension_list = ['jpg','jpeg'])
     model_name = l_col.selectbox('Pick Model', options = list(model_zoo.keys()))
 
     if im:
-        d_im = mono_depth(pil_im = im, model_name=model_name)
+        d_im = mono_depth(pil_im = im, model_name=model_name,
+            _st_asset = r_col.expander('inference info'))
 
         l_col, r_col = st.columns(2)
         l_col.image(im, caption = 'Input Image')
         r_col.image(d_im, caption = 'Depth Map')
+
+        with l_col.form('depth filter'):
+            min_d, max_d = st.slider('Depth Filter', value = (0,255),
+                help = 'smaller value = further away from camera',
+                min_value = 0, max_value = 255)
+            submitted = st.form_submit_button('filter depth')
+            if submitted:
+                depth_mask = ((d_im >= min_d) & (d_im <= max_d))
+                depth_filter_im = im_apply_mask(np.array(im), mask_array = depth_mask)
+                r_col.image(depth_filter_im, caption = 'Depth Filtered Image')
     else:
         st.warning(f'please provide an image :point_up:')
 
159