# MonocularDepth/BTS_infer.py
import os

import gdown
import numpy as np
from PIL import Image

import BTS
def download_model_weight(model_dir,
                          file_key_dict={'bts_latest': "1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"}):
    '''Download the BTS checkpoint from Google Drive into model_dir'''
    if not os.path.isdir(model_dir):
        print(f'--- making model directory: {model_dir}')
        os.makedirs(model_dir)
    fname = list(file_key_dict.keys())[0]
    key = file_key_dict[fname]
    url = f'https://drive.google.com/uc?id={key}&export=download'
    tmp_zip_fp = os.path.join(model_dir, fname)
    print(f'--- downloading model weights from {url}')
    gdown.download(url, tmp_zip_fp, quiet=True)
    # If the checkpoint ever ships as a zip archive again, unpack it here
    # (requires `import zipfile`):
    # with zipfile.ZipFile(tmp_zip_fp, "r") as zip_ref:
    #     for file in zip_ref.namelist():
    #         zip_ref.extract(file, model_dir)
    # os.remove(tmp_zip_fp)
    print(f"--- downloaded model weights to {tmp_zip_fp}", flush=True)
def get_model(model_path='./models/bts_latest'):
    '''Load the BTS model, downloading the weights first if they are missing'''
    if not os.path.isfile(model_path):
        download_model_weight(model_dir=os.path.dirname(model_path))
    model = BTS.BtsController()
    model.load_model(model_path)
    model.eval()
    return model
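
# A minimal usage sketch, assuming the default weight path:
#   model = get_model()  # downloads ./models/bts_latest on first run
#   depth = inference(np.array(Image.open('photo.jpg').convert('RGB')), model)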
def im_max_long_edge(im_np_array, size=1080, return_pil_im=False,
                     resample_algo=Image.LANCZOS, debug=False):
    '''Return an image whose long edge is no longer than the given size
    Args:
        resample_algo: defaults to LANCZOS because it gives the best
            downscaling quality (per the filters comparison table at
            https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table)
    '''
    org_h, org_w, _ = im_np_array.shape
    out_im = None
    if debug:
        print(f'im_max_long_edge: seeing input w,h of {(org_w, org_h)}')
    if max(org_h, org_w) <= size:
        out_im = im_np_array
        if debug:
            print(f'im_max_long_edge: image dim is smaller than max {size}; no resizing required.')
    else:
        wh_ratio = org_w / org_h
        if org_h > org_w:
            # portrait: fix the height to size and scale the width
            h = size
            w = h * wh_ratio
        else:
            # landscape (or square): fix the width to size and scale the height
            w = size
            h = w / wh_ratio
        w = int(w)
        h = int(h)
        pil_im = Image.fromarray(im_np_array).resize((w, h), resample=resample_algo)
        out_im = np.array(pil_im)
        if debug:
            print(f'im_max_long_edge: resizing image to w,h of {(w, h)}')
    return Image.fromarray(out_im) if return_pil_im else out_im
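
# e.g. a 1920x1080 RGB array comes back as 607x1080 (h, w), since the long
# edge is capped at 1080 and the aspect ratio is preserved:
#   small = im_max_long_edge(np.zeros((1080, 1920, 3), dtype=np.uint8))
#   assert small.shape[:2] == (607, 1080)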
def format_depth_map(depth_map, debug=True):
    '''Min-max normalize a depth map and convert it to an 8-bit grayscale array'''
    dmax = depth_map.max()
    dmin = depth_map.min()
    if debug:
        print(f'depth map original min-max: ({dmin}, {dmax})')
    # min-max normalization to [0, 1], then scale to the uint8 range
    formatted = (depth_map - dmin) / (dmax - dmin)
    return (formatted * 255).astype('uint8')
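
# e.g. a depth map spanning [2.0, 10.0] maps 2.0 -> 0, 6.0 -> 127, and
# 10.0 -> 255 (127 rather than 128 because astype('uint8') truncates 127.5)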
def inference(img_array_rgb, model_obj, as_pil=False):
    '''Run BTS depth prediction on an RGB array and return a grayscale depth
    map at the original resolution'''
    h, w, _ = img_array_rgb.shape
    # cap the long edge at 720px to keep the forward pass fast
    img_array_rgb = im_max_long_edge(img_array_rgb, return_pil_im=False, size=720)
    prediction = model_obj.predict(img_array_rgb, is_channels_first=False, normalize=True)
    visual_depth_map = model_obj.depth_map_to_grayimg(prediction)
    visual_depth_map = format_depth_map(visual_depth_map)
    # upsample the visualization back to the input resolution
    visual_depth_map = Image.fromarray(visual_depth_map).resize((w, h), resample=Image.LANCZOS)
    return visual_depth_map if as_pil else np.array(visual_depth_map)
    # Alternative post-processing (unused; requires `import torch`): upsample
    # the raw prediction with bicubic interpolation before converting to uint8.
    # prediction = torch.nn.functional.interpolate(
    #     prediction.unsqueeze(1),
    #     size=img_array_rgb.shape[:2],
    #     mode="bicubic",
    #     align_corners=False,
    # ).squeeze()
    # output = prediction.cpu().numpy()
    # formatted = (output * 255 / np.max(output)).astype('uint8')
    # img = Image.fromarray(formatted)
    # return img
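
if __name__ == '__main__':
    # Smoke test (a minimal sketch): 'test.jpg' is a placeholder path; the BTS
    # checkpoint is downloaded on first run by get_model().
    model = get_model()
    rgb = np.array(Image.open('test.jpg').convert('RGB'))
    depth_im = inference(rgb, model, as_pil=True)
    depth_im.save('test_depth.png')
    print(f'saved depth map to test_depth.png ({depth_im.size[0]}x{depth_im.size[1]})')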