File size: 5,820 Bytes
dfcd969
 
a327756
dfcd969
 
 
 
a327756
dfcd969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b240372
 
 
dfcd969
b240372
 
dfcd969
b240372
 
 
 
 
dfcd969
b240372
 
dfcd969
b240372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfcd969
b240372
 
 
dfcd969
 
b240372
a327756
 
 
 
b240372
a327756
 
dfcd969
b240372
 
 
 
a327756
 
 
 
 
 
b240372
 
 
 
 
 
 
 
 
dfcd969
 
b240372
 
 
 
 
 
 
dfcd969
b240372
dfcd969
 
b240372
 
 
dfcd969
a327756
dfcd969
a327756
dfcd969
 
b240372
 
dfcd969
 
 
a327756
b240372
 
 
 
 
 
 
 
 
 
dfcd969
 
 
 
a327756
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import io
import os
import sys
import time
import urllib.request as urllib

import numpy as np
import streamlit as st
from PIL import Image, ImageFilter, ImageOps

import DPT
import BTS_infer

### Some Utils Functions ###
def get_image(st_asset = st.sidebar, as_np_arr = False, extension_list = ['jpg', 'jpeg', 'png']):
	'''
	Ask the user for an image, either by URL or by file upload, and load it.

	A checkbox rendered on `st_asset` picks the input mode: a text box for an
	image URL when ticked, otherwise a file uploader.

	Args:
		st_asset: streamlit container the widgets are rendered on
		as_np_arr: if True, return the image as a numpy array instead of a PIL Image
		extension_list: file extensions accepted by the file uploader
			(only read, never mutated, so the shared list default is safe)

	Returns:
		PIL Image (or numpy array when `as_np_arr`), or None if nothing was provided

	Raises:
		ValueError: if the URL is not an http(s) URL
	'''
	image_url, image_fh = None, None
	if st_asset.checkbox('use image URL?'):
		image_url = st_asset.text_input("Enter Image URL")
	else:
		image_fh = st_asset.file_uploader(label = "Update your image", type = extension_list)

	im = None
	if image_url:
		# urlopen also accepts file:, ftp: and data: URLs; the URL comes straight
		# from the user, so only plain web URLs are allowed through
		if not image_url.strip().lower().startswith(('http://', 'https://')):
			raise ValueError(f'get_image: only http(s) image URLs are supported, got {image_url!r}')

		# read everything up front so the connection is closed right away;
		# PIL decodes lazily from the in-memory buffer afterwards
		with urllib.urlopen(image_url) as response:
			payload = response.read()
		im = Image.open(io.BytesIO(payload))
	elif image_fh:
		im = Image.open(image_fh)

	if im and as_np_arr:
		im = np.array(im)
	return im

def show_miro_logo(use_column_width = False, width = 100, st_asset= st.sidebar, str_color = 'white'):
    '''
    Render the Miro logo on the given streamlit container.

    Args:
        use_column_width: forwarded as-is to `st_asset.image`
        width: forwarded as-is to `st_asset.image`
        st_asset: streamlit container the logo is drawn on
        str_color: color name inserted into the logo file name
    '''
    image_options = dict(
        use_column_width = use_column_width,
        channels = 'BGR',
        output_format = 'PNG',
        width = width,
    )
    st_asset.image(
        f'https://miro-ps-bucket-copy.s3.us-west-2.amazonaws.com/storage/jho/web_asset/logo/miro_logo_{str_color}.png',
        **image_options
    )

def im_apply_mask(im_rgb_array, mask_array, get_pil_im = False, bg_rgb_tup = None,
	bg_blur_radius = None, bg_greyscale = False, mask_gblur_radius = 0):
	'''
	Keep the masked part of an image and replace or remove everything else.

	The background options are checked in this order and the first one set wins:
	`bg_rgb_tup`, `bg_blur_radius`, `bg_greyscale`. If none is set, the mask is
	attached as an alpha channel, making the unmasked area transparent.

	ref: https://stackoverflow.com/questions/47723154/how-to-use-pil-paste-with-mask
	ref: https://stackoverflow.com/questions/62273005/compositing-images-by-blurred-mask-in-numpy
	ref: https://stackoverflow.com/questions/62968174/for-pil-imagefilter-gaussianblur-how-what-kernel-is-used-and-does-the-radius-par

	Args:
		im_rgb_array: image as a 3-d numpy array (height, width, channels)
		mask_array: 2-d numpy array with the same height and width as the image;
			it is scaled by 255, so presumably boolean or 0..1 values (confirm with callers)
		get_pil_im: if True return a PIL Image, otherwise a numpy array
		bg_rgb_tup: if given, return a 3-channel image with color background instead of transparent
		bg_blur_radius: if given, return a 3-channel image with GaussianBlur applied to the background
		bg_greyscale: if True, use a greyscale version of the image as the background
		mask_gblur_radius: if > 0, soften the mask edge with a GaussianBlur of this radius;
			0 or None leaves the mask untouched

	Returns:
		the composited image, as PIL Image or numpy array depending on `get_pil_im`

	Raises:
		ValueError: if the mask height/width does not match the image
	'''
	h, w, _ = im_rgb_array.shape
	m_h, m_w = mask_array.shape

	if (h, w) != (m_h, m_w):
		raise ValueError(f'im_apply_mask: mask_array size {(m_h, m_w)} must match im_rgb_array {(h, w)}')

	im = Image.fromarray(im_rgb_array)

	# convert bitwise mask from np to pillow
	# ref: https://note.nkmk.me/en/python-pillow-paste/
	pil_mask = Image.fromarray(np.uint8(255 * mask_array))
	if mask_gblur_radius and mask_gblur_radius > 0:
		pil_mask = pil_mask.filter(ImageFilter.GaussianBlur(radius = mask_gblur_radius))

	# build the background (if any); pasting through the PIL mask rather than
	# numpy indexing is what lets a blurred (soft) mask blend smoothly
	bg_im = None
	if bg_rgb_tup:
		solid = np.zeros([h, w, 3], dtype = np.uint8) 	# black
		solid[:, :] = bg_rgb_tup 						# apply color
		bg_im = Image.fromarray(solid)
	elif bg_blur_radius:
		# filter() returns a new image, the original stays sharp
		bg_im = im.filter(ImageFilter.GaussianBlur(radius = bg_blur_radius))
	elif bg_greyscale:
		# greyscale 1-channel back to 3-channel so the color foreground can be pasted on it
		bg_im = ImageOps.grayscale(im).convert('RGB')

	if bg_im is None:
		# no background requested: unmasked area becomes transparent
		im.putalpha(pil_mask)
	else:
		bg_im.paste(im, mask = pil_mask)
		im = bg_im

	return im if get_pil_im else np.array(im)

### Streamlit App ###
# @st.experimental_memo
@st.cache(allow_output_mutation = True)
def get_model_zoo():
	'''
	Build the registry of available depth models.

	Returns:
		dict keyed by model name ('DPT', 'BTS'); each value holds the
		inference callable under 'infer_func' and the loaded model under 'model'
	'''
	dpt_entry = dict(infer_func = DPT.inference, model = DPT.load_model())
	bts_entry = dict(infer_func = BTS_infer.inference, model = BTS_infer.get_model())
	return {'DPT': dpt_entry, 'BTS': bts_entry}

# @st.experimental_memo(suppress_st_warning=True)
@st.cache(suppress_st_warning=True,
	hash_funcs={st.delta_generator.DeltaGenerator: lambda _:None})
def mono_depth(pil_im, model_name, _st_asset = None):
	'''
	Run the selected depth model on an image and return its depth map.

	Args:
		pil_im: input image, converted to a numpy array before inference
		model_name: key into the model zoo ('DPT' or 'BTS')
		_st_asset: optional streamlit container; when given, a summary of the
			run (model, timing, depth map shape/type/range) is written into it.
			It is hashed to None by the cache decorator above.

	Returns:
		whatever the model's inference function returns for the image
	'''
	s_time = time.time()
	zoo_entry = get_model_zoo()[model_name]
	run_inference, network = zoo_entry['infer_func'], zoo_entry['model']
	depth_im = run_inference(img_array_rgb = np.array(pil_im), model_obj = network)

	# nothing to report into, hand back the result right away
	if not _st_asset:
		return depth_im

	with _st_asset:
		st.info(f'''
				model name: {model_name}\n
				inference time: `{round(time.time()-s_time,2)}` seconds\n
				depth image shape: {np.array(depth_im).shape}\n
				depth image type: {type(depth_im)}\n
				depth map min-max: {depth_im.min()}, {depth_im.max()}
				''')
	return depth_im

def Main(): # streamlit version 1.9.2
	'''
	Streamlit page: pick a model and an image, show the predicted depth map,
	and optionally filter the image by a depth range.
	'''
	st.set_page_config(
		layout = 'wide',
		page_title = 'Monocular Depth',
		page_icon = 'https://miro.io/favicon-32x32.png',
		initial_sidebar_state = 'collapsed'
		)
	l_col, r_col = st.columns(2)
	show_miro_logo(st_asset = l_col, str_color = 'purple', width = 200)
	with l_col.expander('Monocular Depth: CNN vs Transformers'):
		st.info('''
		Comparsion of two [SoTA](https://paperswithcode.com/sota/monocular-depth-estimation-on-nyu-depth-v2) models:
		[BTS (CNN), 2019](https://github.com/ErenBalatkan/Bts-PyTorch)
		and [DPT (Transformer), 2021](https://huggingface.co/Intel/dpt-large)
		''')
	model_zoo = get_model_zoo()
	im = get_image(st_asset = r_col.expander('Input Image', expanded = True), extension_list = ['jpg','jpeg'])
	model_name = l_col.selectbox('Pick Model', options = list(model_zoo.keys()))

	if im:
		# greyscale / CMYK JPEGs and RGBA images fetched by URL are not 3-channel RGB;
		# the inference functions take `img_array_rgb` and im_apply_mask unpacks a
		# 3-d shape, so normalize once up front
		if im.mode != 'RGB':
			im = im.convert('RGB')

		d_im = mono_depth(pil_im = im, model_name=model_name,
				_st_asset = r_col.expander('inference info'))

		# fresh row of columns so the images sit below the controls
		l_col, r_col = st.columns(2)
		l_col.image(im, caption = 'Input Image')
		r_col.image(d_im, caption = 'Depth Map')

		with l_col.form('depth filter'):
			min_d, max_d = st.slider('Depth Filter', value = (0,255),
								help = 'smaller value = further away from camera',
								min_value = 0, max_value = 255)
			submitted = st.form_submit_button('filter depth')
		if submitted:
			# keep only the pixels whose depth value lies inside the chosen range
			depth_mask = ((d_im>= min_d) & (d_im<=max_d))
			depth_filter_im = im_apply_mask(np.array(im),mask_array = depth_mask)
			r_col.image(depth_filter_im, caption = 'Depth Filtered Image')
	else:
		st.warning('please provide an image :point_up:')

# build the page only when this file is executed as a script, not when imported
if __name__ == '__main__':
	Main()