# Spaces:
# Build error
# Build error
import io
import os
import sys
import time
import urllib.request as urllib

import numpy as np
import streamlit as st
from PIL import Image, ImageFilter, ImageOps

import BTS_infer
import DPT
### Some Utils Functions ### | |
def get_image(st_asset = st.sidebar, as_np_arr = False, extension_list = None):
    '''
    Let the user provide an image by URL or by file upload, and load it.

    A checkbox rendered in ``st_asset`` toggles between a URL text box and a
    file uploader.

    Args:
        st_asset: Streamlit container the input widgets are rendered into.
        as_np_arr: if True, return the image as a numpy array instead of a
            PIL Image.
        extension_list: file extensions accepted by the uploader; defaults to
            ['jpg', 'jpeg', 'png'].

    Returns:
        the loaded image (PIL Image, or numpy array when ``as_np_arr``), or
        None when nothing has been provided yet or the URL is not http(s).
    '''
    # None sentinel instead of a shared mutable list default
    if extension_list is None:
        extension_list = ['jpg', 'jpeg', 'png']

    image_url, image_fh = None, None
    if st_asset.checkbox('use image URL?'):
        image_url = st_asset.text_input("Enter Image URL").strip()
    else:
        image_fh = st_asset.file_uploader(label = "Update your image", type = extension_list)

    im = None
    if image_url:
        # urlopen also accepts file:// (and other) schemes; only allow web URLs
        # so a user cannot make the server read its own local files.
        if not image_url.lower().startswith(('http://', 'https://')):
            st_asset.error('Image URL must start with http:// or https://')
            return None
        # `with` closes the connection once the payload is read; the timeout
        # keeps a slow/unresponsive host from hanging the app indefinitely.
        with urllib.urlopen(image_url, timeout = 30) as response:
            im = Image.open(io.BytesIO(response.read()))
    elif image_fh is not None:
        im = Image.open(image_fh)

    if im is not None and as_np_arr:
        im = np.array(im)
    return im
def show_miro_logo(use_column_width = False, width = 100, st_asset= st.sidebar, str_color = 'white'):
    '''
    Display the Miro logo in a Streamlit container.

    Args:
        use_column_width: forwarded to ``st_asset.image``.
        width: display width in pixels, forwarded to ``st_asset.image``.
        st_asset: Streamlit container that receives the image.
        str_color: logo colour variant; selects the ``miro_logo_<str_color>.png`` file.
    '''
    st_asset.image(
        f'https://miro-ps-bucket-copy.s3.us-west-2.amazonaws.com/storage/jho/web_asset/logo/miro_logo_{str_color}.png',
        width = width,
        use_column_width = use_column_width,
        output_format = 'PNG',
        channels = 'BGR',
    )
def im_apply_mask(im_rgb_array, mask_array, get_pil_im = False, bg_rgb_tup = None,
    bg_blur_radius = None, bg_greyscale = False, mask_gblur_radius = 0):
    '''
    Keep the pixels of an image selected by a mask and drop or replace the rest.

    return either a np array or PIL Image (see ``get_pil_im``)
    ref: https://stackoverflow.com/questions/47723154/how-to-use-pil-paste-with-mask
    ref: https://stackoverflow.com/questions/62273005/compositing-images-by-blurred-mask-in-numpy
    ref: https://stackoverflow.com/questions/62968174/for-pil-imagefilter-gaussianblur-how-what-kernel-is-used-and-does-the-radius-par
    Args:
        im_rgb_array: image array of shape (H, W) or (H, W, C); presumably
            uint8 RGB, since it goes through Image.fromarray.
        mask_array: (H, W) array. It is scaled by 255 and clipped into an 8-bit
            mask, so True/1 keeps a pixel, False/0 drops it, and fractional
            values give partial opacity.
        get_pil_im: if True return a PIL Image, otherwise a numpy array.
        bg_rgb_tup: if given, return a 3-channel image with color background instead of transparent
        bg_blur_radius: if given, return the image with GaussianBlur applied to the background
        bg_greyscale: if True, return a 3-channel image with a greyscale background
        mask_gblur_radius: if > 0, Gaussian-blur the mask first to feather its edges
    Only the first applicable background option is used, in the order
    bg_rgb_tup, bg_blur_radius, bg_greyscale. With none of them the mask
    becomes the alpha channel of the result.
    Raises:
        ValueError: if the mask's height/width differ from the image's.
    '''
    h, w = im_rgb_array.shape[:2]
    m_h, m_w = mask_array.shape
    if (h, w) != (m_h, m_w):
        raise ValueError(f'im_apply_mask: mask_array size {(m_h, m_w)} must match im_rgb_array {(h, w)}')
    im = Image.fromarray(im_rgb_array)

    # convert the mask from np to an 8-bit pillow mask; clip so values outside
    # [0, 1] saturate instead of wrapping around in the uint8 cast
    # ref: https://note.nkmk.me/en/python-pillow-paste/
    pil_mask = Image.fromarray(np.clip(255 * mask_array, 0, 255).astype(np.uint8))
    if mask_gblur_radius > 0:
        pil_mask = pil_mask.filter(ImageFilter.GaussianBlur(radius = mask_gblur_radius))

    # build the background (if any) the masked image is pasted onto
    bg_im = None
    if bg_rgb_tup is not None:
        bg_arr = np.zeros([h, w, 3], dtype = np.uint8)
        bg_arr[:, :] = bg_rgb_tup  # broadcast the colour over every pixel
        bg_im = Image.fromarray(bg_arr)
    elif bg_blur_radius:
        # filter() returns a new image, so `im` itself stays sharp
        bg_im = im.filter(ImageFilter.GaussianBlur(radius = bg_blur_radius))
    elif bg_greyscale:
        # greyscale 1-channel back to 3-channel so the colour foreground can be pasted
        bg_im = ImageOps.grayscale(im).convert('RGB')

    if bg_im is None:
        im.putalpha(pil_mask)
    else:
        bg_im.paste(im, mask = pil_mask)
        im = bg_im
    return im if get_pil_im else np.array(im)
### Streamlit App ### | |
# @st.experimental_memo | |
# Streamlit re-executes this script on every widget interaction. An uncached
# loader would therefore reload both networks on each rerun, and again inside
# mono_depth, which also calls get_model_zoo(). st.experimental_memo (see the
# commented-out decorator above) pickles its return values, which does not suit
# model objects. The resource cache keeps the objects themselves, shared across
# reruns and sessions. That is safe only if inference does not mutate the
# models (assumed, not verified here). `cache_resource` is the newer Streamlit
# name for `experimental_singleton`; the identity fallback keeps very old
# versions working (uncached).
_cache_model_zoo = getattr(st, 'cache_resource', None) or getattr(st, 'experimental_singleton', lambda func: func)

@_cache_model_zoo
def get_model_zoo():
    '''
    Return the available depth models keyed by display name.

    Each entry maps to {'infer_func': <inference callable>, 'model': <loaded model>},
    built from DPT.load_model() and BTS_infer.get_model().
    '''
    model_zoo = {
        'DPT': {'infer_func': DPT.inference, 'model': DPT.load_model()},
        'BTS': {'infer_func': BTS_infer.inference, 'model': BTS_infer.get_model()},
    }
    return model_zoo
# @st.experimental_memo(suppress_st_warning=True) | |
def mono_depth(pil_im, model_name, _st_asset = None):
    '''
    Estimate a depth map for ``pil_im`` with the model named ``model_name``.

    Args:
        pil_im: input image; converted with np.array and passed to the model's
            inference function as ``img_array_rgb``.
        model_name: a key of get_model_zoo() (e.g. 'DPT' or 'BTS').
        _st_asset: optional Streamlit container. When given, an info box with
            the inference time and depth-map statistics is written into it.
            The leading underscore matches Streamlit's convention for arguments
            its memo decorator should not hash (see the commented decorator above).
    Returns:
        the depth map exactly as returned by the model's inference function.
    '''
    model_entry = get_model_zoo()[model_name]
    # start timing after model lookup so the reported figure is inference only
    # (previously it also included loading the models)
    s_time = time.perf_counter()
    depth_im = model_entry['infer_func'](img_array_rgb = np.array(pil_im),
                                         model_obj = model_entry['model'])
    elapsed = round(time.perf_counter() - s_time, 2)

    if _st_asset is not None:
        # np.asarray so shape/min/max work whether the model returns an array or an image
        depth_arr = np.asarray(depth_im)
        with _st_asset:
            st.info(f'''
model name: {model_name}\n
inference time: `{elapsed}` seconds\n
depth image shape: {depth_arr.shape}\n
depth image type: {type(depth_im)}\n
depth map min-max: {depth_arr.min()}, {depth_arr.max()}
''')
    return depth_im
def Main(): # streamlit version 1.9.2
    '''
    Build the Streamlit page.

    The page takes an input image, runs the selected depth model, shows the
    depth map, and offers a depth-range filter that masks the input image.
    '''
    st.set_page_config(
        layout = 'wide',
        page_title = 'Monocular Depth',
        page_icon = 'https://miro.io/favicon-32x32.png',
        initial_sidebar_state = 'collapsed'
    )
    l_col, r_col = st.columns(2)
    show_miro_logo(st_asset = l_col, str_color = 'purple', width = 200)
    with l_col.expander('Monocular Depth: CNN vs Transformers'):
        st.info('''
Comparison of two [SoTA](https://paperswithcode.com/sota/monocular-depth-estimation-on-nyu-depth-v2) models:
[BTS (CNN), 2019](https://github.com/ErenBalatkan/Bts-PyTorch)
and [DPT (Transformer), 2021](https://huggingface.co/Intel/dpt-large)
''')
    model_zoo = get_model_zoo()
    im = get_image(st_asset = r_col.expander('Input Image', expanded = True), extension_list = ['jpg', 'jpeg'])
    model_name = l_col.selectbox('Pick Model', options = list(model_zoo.keys()))

    if im is None:
        st.warning('please provide an image :point_up:')
        return

    # URL input can be PNG/greyscale/RGBA; the models take `img_array_rgb` and
    # im_apply_mask composites an RGB image, so normalise to 3-channel RGB
    im = im.convert('RGB')
    d_im = mono_depth(pil_im = im, model_name = model_name,
                      _st_asset = r_col.expander('inference info'))

    in_col, out_col = st.columns(2)
    in_col.image(im, caption = 'Input Image')
    out_col.image(d_im, caption = 'Depth Map')
    with in_col.form('depth filter'):
        min_d, max_d = st.slider('Depth Filter', value = (0, 255),
                                 help = 'smaller value = further away from camera',
                                 min_value = 0, max_value = 255)
        submitted = st.form_submit_button('filter depth')
    if submitted:
        # np.asarray so the comparisons work whether d_im is an array or an image
        depth_arr = np.asarray(d_im)
        depth_mask = (depth_arr >= min_d) & (depth_arr <= max_d)
        depth_filter_im = im_apply_mask(np.array(im), mask_array = depth_mask)
        out_col.image(depth_filter_im, caption = 'Depth Filtered Image')
if __name__ == "__main__":
    # build the page only when this file is executed as a script, not on import
    Main()