MonocularDepth / BTS_infer.py
ohjho
tested BTS model and added to the app
b240372
import BTS, cv2, torch, gdown, os, zipfile
import numpy as np
from PIL import Image
def download_model_weight(model_dir,
        file_key_dict = {'bts_latest':"1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"}):
    ''' Download model weight files from Google Drive into model_dir
    Args:
        model_dir: directory the files are written to; it is created when it
            does not exist yet. An empty string means the current working directory.
        file_key_dict: mapping of output file name -> Google Drive file id.
            Every entry is downloaded. The default dict is only read, never mutated.
    Returns:
        None. Progress and the per-file outcome are reported on stdout.
    '''
    # os.makedirs('') raises, so only create a directory when one was actually named
    if model_dir and not os.path.isdir(model_dir):
        print(f'--- making model directory: {model_dir}')
        os.makedirs(model_dir, exist_ok = True)

    for fname, key in file_key_dict.items():
        url = f'https://drive.google.com/uc?id={key}&export=download'
        out_fp = os.path.join(model_dir, fname)
        print(f'--- downloading model weights from {url}')
        gdown.download(url, out_fp, quiet = True)

        # only claim success when the file is really on disk
        if os.path.isfile(out_fp):
            print(f"--- downloaded model weights to {out_fp}", flush=True)
        else:
            print(f"--- failed to download model weights to {out_fp}", flush=True)
def get_model(model_path = './models/bts_latest'):
    ''' Build a BTS controller with weights loaded and switched to eval mode
    Args:
        model_path: path of the weights file. When no file exists at this path,
            download_model_weight is called on the path's directory first.
    Returns:
        the BTS.BtsController instance after load_model(model_path) and eval()
    '''
    if not os.path.isfile(model_path):
        # a bare file name has an empty dirname; fall back to '.' so the
        # downloader is never asked to create a directory named ''
        # NOTE(review): the downloader is called with its default file names, so
        # the basename of model_path presumably has to match one of them - confirm
        download_model_weight(model_dir = os.path.dirname(model_path) or '.')

    model = BTS.BtsController()
    model.load_model(model_path)
    model.eval()
    return model
def im_max_long_edge(im_np_array, size = 1080, return_pil_im = False,
        resample_algo = Image.LANCZOS, debug = False):
    ''' Return an image whose long edge is no longer than the given size
    Args:
        im_np_array: image as a numpy array whose first two axes are (height, width);
            both (h, w, c) and plain (h, w) arrays are accepted
        size: maximum allowed length of the longer edge, in pixels
        return_pil_im: return a PIL Image instead of a numpy array
        resample_algo: default to LANCZOS b/c it gives best downscaling quality (per https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table)
        debug: print the input dimensions and the resize decision
    Returns:
        the input array unchanged when it already fits, otherwise a resized copy
        with the aspect ratio preserved (as a PIL Image when return_pil_im is True)
    '''
    # only the first two axes matter, so do not insist on a channel axis
    org_h, org_w = im_np_array.shape[:2]
    if debug:
        print(f'im_max_long_edge: seeing input w,h of {(org_w, org_h)}')

    if max(org_h, org_w) <= size:
        out_im = im_np_array
        if debug:
            print(f'im_max_long_edge: image dim is smaller than max {size}. no resizing required.')
    else:
        wh_ratio = org_w / org_h
        if org_h > org_w:
            # fix h to size
            h = size
            w = h * wh_ratio
        else:
            # fix w to size
            w = size
            h = w / wh_ratio
        # truncation can reach 0 on extreme aspect ratios; keep at least
        # one pixel so the resize target is always valid
        w = max(1, int(w))
        h = max(1, int(h))

        pil_im = Image.fromarray(im_np_array).resize((w,h), resample = resample_algo)
        out_im = np.array(pil_im)
        if debug:
            print(f'im_max_long_edge: resizing image to w,h of {(w,h)}')

    return Image.fromarray(out_im) if return_pil_im else out_im
def format_depth_map(depth_map, debug = True):
    ''' Min-max normalize a depth map into an 8-bit array
    Args:
        depth_map: numpy array of depth values
        debug: print the original min and max of the depth map
    Returns:
        uint8 array of the same shape where the smallest input value maps to 0
        and the largest to 255; all zeros when every input value is identical
    '''
    dmin = depth_map.min()
    dmax = depth_map.max()
    if debug:
        print(f'depth map origin min-max: ({dmin}, {dmax})')

    span = dmax - dmin
    if span == 0:
        # a constant map has no range to stretch; avoid the 0/0 division
        # that would otherwise turn every pixel into NaN
        return np.zeros(depth_map.shape, dtype = 'uint8')

    # min-max normalization
    formatted = (depth_map - dmin) / span
    return (formatted * 255).astype('uint8')
def inference(img_array_rgb, model_obj, as_pil = False):
    ''' Predict a depth map for an image and return it at the input's resolution
    Args:
        img_array_rgb: image as a 3-axis numpy array (height, width, channels)
        model_obj: object providing predict() and depth_map_to_grayimg()
        as_pil: return a PIL Image instead of a numpy array
    Returns:
        the formatted depth map resized back to the input's width and height
    '''
    # remember the caller's resolution before the input gets shrunk
    orig_height, orig_width, _ = img_array_rgb.shape

    # cap the long edge at 720px for the model input
    model_input = im_max_long_edge(img_array_rgb, size = 720, return_pil_im = False)
    raw_depth = model_obj.predict(model_input, is_channels_first = False, normalize = True)

    gray_depth = model_obj.depth_map_to_grayimg(raw_depth)
    depth_uint8 = format_depth_map(gray_depth)

    # scale the depth map back up to the original dimensions
    depth_image = Image.fromarray(depth_uint8).resize(
        (orig_width, orig_height), resample = Image.LANCZOS)

    if as_pil:
        return depth_image
    return np.array(depth_image)
# prediction = torch.nn.functional.interpolate(
# prediction.unsqueeze(1),
# size=img_array_rgb.shape[:2],
# mode="bicubic",
# align_corners=False,
# ).squeeze()
#
# output = prediction.cpu().numpy()
# formatted = (output * 255 / np.max(output)).astype('uint8')
# img = Image.fromarray(formatted)
# return img