MonocularDepth / app.py
ohjho
testing DPT app
dfcd969
raw history blame
No virus
3.23 kB
import streamlit as st
import os, sys, io
import urllib.request as urllib
import numpy as np
from PIL import Image
import DPT
### Some Utils Functions ###
def get_image(st_asset=st.sidebar, as_np_arr=False, extension_list=None):
    '''Let the user supply an image (URL or file upload) and load it.

    Renders either a URL text box or a file uploader in ``st_asset``,
    depending on the "use image URL?" checkbox.

    Args:
        st_asset: Streamlit container the input widgets are drawn in
            (the sidebar by default).
        as_np_arr: if True, return the image as a numpy array instead of a
            PIL image.
        extension_list: file extensions accepted by the uploader; defaults
            to ['jpg', 'jpeg', 'png'].

    Returns:
        PIL.Image.Image (or numpy.ndarray when ``as_np_arr``) once the user
        has provided an image, otherwise None.
    '''
    # None sentinel instead of a shared mutable list default.
    if extension_list is None:
        extension_list = ['jpg', 'jpeg', 'png']

    image_url, image_fh = None, None
    if st_asset.checkbox('use image URL?'):
        image_url = st_asset.text_input("Enter Image URL")
    else:
        image_fh = st_asset.file_uploader(label="Update your image", type=extension_list)

    im = None
    if image_url:
        image_url = image_url.strip()
        # The URL is user input: urlopen would also accept file:// and other
        # schemes, so only plain web URLs are fetched.
        if not image_url.lower().startswith(('http://', 'https://')):
            st_asset.warning('Image URL must start with http:// or https://')
        else:
            # Close the HTTP response as soon as its bytes have been read.
            with urllib.urlopen(image_url) as response:
                im = Image.open(io.BytesIO(response.read()))
    elif image_fh:
        im = Image.open(image_fh)

    if im is not None and as_np_arr:
        im = np.array(im)
    return im
def show_miro_logo(use_column_width=False, width=100, st_asset=st.sidebar):
    '''Display the Miro logo, fetched from a fixed URL, in a Streamlit container.

    Args:
        use_column_width: forwarded to ``st_asset.image``.
        width: display width in pixels, forwarded to ``st_asset.image``.
        st_asset: Streamlit container to draw into; the sidebar by default.
    '''
    st_asset.image(
        'https://miro.medium.com/max/1400/0*qLL-32srlq6Y_iTm.png',
        width=width,
        use_column_width=use_column_width,
        output_format='PNG',
        channels='BGR',
    )
def im_draw_bbox(pil_im, x0, y0, x1, y1, color='black', width=3, caption=None,
                 bbv_label_only=False):
    '''Draw a bounding box (and optional caption) on a PIL image.

    The drawing is done on a numpy copy of ``pil_im``; the input image is
    left unchanged and a new PIL image is returned.

    Args:
        pil_im: input PIL image.
        x0, y0, x1, y1: box corners in pixels. Float values are converted to
            int (with a warning).
        color: color name as read by PIL.ImageColor; used for the box/flag
            line and the caption background.
        width: rectangle line thickness (unused when ``bbv_label_only``).
        caption: optional label text drawn with the box.
        bbv_label_only: if True, draw only a flag-style label via
            ``bbox_visualizer.draw_flag_with_label`` instead of a rectangle;
            requires ``caption``.

    Returns:
        PIL.Image.Image with the annotation drawn.

    Raises:
        ValueError: if ``bbv_label_only`` is True and ``caption`` is empty/None.
    '''
    # These names were used but never imported; importing them here keeps
    # bbox_visualizer an optional dependency needed only by this function.
    import warnings
    import bbox_visualizer as bbv
    from PIL import ImageColor

    coords = (x0, y0, x1, y1)
    if any(isinstance(c, float) for c in coords):
        warnings.warn('im_draw_bbox: at least one of x0,y0,x1,y1 is of the type float and is converted to int.')
    bbox = [int(c) for c in coords]

    if bbv_label_only and not caption:
        raise ValueError('im_draw_bbox: bbv_label_only is True but caption is None')

    rgb = ImageColor.getrgb(color)
    im_array = np.array(pil_im)  # copy; pil_im itself is not modified

    if bbv_label_only:
        im_array = bbv.draw_flag_with_label(im_array,
                                            label=caption,
                                            bbox=bbox,
                                            line_color=rgb,
                                            text_bg_color=rgb)
    else:
        im_array = bbv.draw_rectangle(im_array,
                                      bbox=bbox,
                                      bbox_color=rgb,
                                      thickness=width)
        if caption:
            im_array = bbv.add_label(im_array,
                                     label=caption,
                                     bbox=bbox,
                                     text_bg_color=rgb)
    return Image.fromarray(im_array)
### Streamlit App ###
def mod_DPT(pil_im, model_def):
    '''Run DPT monocular depth inference on an image.

    Args:
        pil_im: input PIL image. Because the array is passed to
            ``DPT.inference`` as ``img_array_rgb``, non-RGB images (e.g.
            grayscale 'L' or 'CMYK' JPEGs, 'RGBA' PNGs) are converted to RGB
            first; otherwise the array would not have 3 channels.
        model_def: model definition forwarded to ``DPT.inference``
            (presumably the object returned by ``DPT.load_model()`` — confirm).

    Returns:
        Whatever ``DPT.inference`` returns.
    '''
    if isinstance(pil_im, Image.Image) and pil_im.mode != 'RGB':
        pil_im = pil_im.convert('RGB')
    depth_im = DPT.inference(img_array_rgb=np.array(pil_im), model_def=model_def)
    return depth_im
def Main(model_dict=None):
    '''Render the monocular-depth demo page.

    Layout: logo and description on the left; image input and model picker
    on the right. Once an image is provided, DPT depth inference is run and
    the input image is shown next to the resulting depth map.

    Args:
        model_dict: optional dict of preloaded model definitions keyed by
            model name. If it has a 'DPT' entry, that model definition is
            used; otherwise ``DPT.load_model()`` is called once an image is
            provided.
    '''
    # Kept as the first Streamlit call; older Streamlit versions require
    # set_page_config to precede every other command.
    st.set_page_config(layout='wide')
    l_col, r_col = st.columns(2)
    show_miro_logo(st_asset=l_col)
    with l_col.expander('Monocular Depth: CNN vs Transformers'):
        st.info(
            'Comparison of two models: '
            '[BTS (CNN)](https://github.com/ErenBalatkan/Bts-PyTorch) '
            'and [DPT (Transformer)](https://huggingface.co/Intel/dpt-large)'
        )

    im = get_image(st_asset=r_col.expander('Input Image', expanded=True),
                   extension_list=['jpg', 'jpeg'])
    model_name = r_col.selectbox('Pick Model', options=['DPT', 'BTS'])
    if not im:
        st.warning('please provide an image :point_up:')
        return

    if model_name != 'DPT':
        # Only DPT inference (mod_DPT) is wired up in this app.
        st.warning(f'{model_name} is not available yet; showing the DPT result instead.')

    model_def = model_dict.get('DPT') if isinstance(model_dict, dict) else None
    if model_def is None:
        model_def = DPT.load_model()
    d_im = mod_DPT(pil_im=im, model_def=model_def)

    l_col, r_col = st.columns(2)
    l_col.image(im, caption='Input Image')
    # Previously referenced an undefined `saliency_im`; the depth map is d_im.
    r_col.image(d_im, caption='Depth Map')
if __name__ == '__main__':
    # No `load_model` is defined or imported at module level. Main loads the
    # DPT model itself via DPT.load_model() once an image is provided, so no
    # preloaded models are passed in.
    Main(model_dict=None)