import os
import zipfile  # noqa: F401 -- currently unused; kept for archived-weight support

import cv2  # noqa: F401 -- currently unused in this module
import gdown
import numpy as np
import torch  # noqa: F401 -- currently unused in this module
from PIL import Image

import BTS

# Google Drive file id of the default pretrained BTS weights, keyed by the
# file name the weights are saved under.
DEFAULT_FILE_KEY_DICT = {'bts_latest': "1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"}


def download_model_weight(model_dir, file_key_dict=None, output_fp=None):
    """Download model weights from Google Drive using gdown.

    Only the first entry of ``file_key_dict`` is downloaded.

    Args:
        model_dir: Directory to save the weights in; created if missing.
            An empty string means the current working directory.
        file_key_dict: Mapping ``{file_name: google_drive_file_id}``.
            Defaults to ``DEFAULT_FILE_KEY_DICT``.
        output_fp: Optional explicit destination path. When given it takes
            precedence over ``model_dir``/``file_name``, so a caller can
            download straight to the path it will later load from.

    Returns:
        The path the weights were written to.

    Raises:
        ValueError: if ``file_key_dict`` is empty.
        RuntimeError: if no file exists at the destination after downloading.
    """
    if file_key_dict is None:
        file_key_dict = DEFAULT_FILE_KEY_DICT
    if not file_key_dict:
        raise ValueError('file_key_dict must contain at least one entry')

    fname, key = next(iter(file_key_dict.items()))
    if output_fp is None:
        weights_fp = os.path.join(model_dir, fname)
        target_dir = model_dir
    else:
        weights_fp = output_fp
        target_dir = os.path.dirname(output_fp)

    # os.makedirs('') raises, so an empty directory component means "cwd".
    if target_dir and not os.path.isdir(target_dir):
        print(f'--- making model directory: {target_dir}')
        os.makedirs(target_dir, exist_ok=True)

    url = f'https://drive.google.com/uc?id={key}&export=download'
    print(f'--- downloading model weights from {url}')
    gdown.download(url, weights_fp, quiet=True)
    # Don't report success on the strength of the call alone: confirm the
    # weights file actually landed where callers will look for it.
    if not os.path.isfile(weights_fp):
        raise RuntimeError(
            f'failed to download model weights from {url} to {weights_fp}')
    print(f"--- downloaded model weights to {weights_fp}", flush=True)
    return weights_fp


def get_model(model_path='./models/bts_latest'):
    """Build a ``BTS.BtsController``, load weights into it and call ``eval()``.

    Args:
        model_path: Path of the weight file. If it does not exist, the
            default weights are downloaded to exactly this path first.

    Returns:
        The controller after ``load_model(model_path)`` and ``eval()``.
    """
    if not os.path.isfile(model_path):
        # Download to model_path itself: saving under the dict's file name
        # would leave nothing at model_path whenever its basename differs.
        download_model_weight(model_dir=os.path.dirname(model_path),
                              output_fp=model_path)
    model = BTS.BtsController()
    model.load_model(model_path)
    model.eval()
    return model


def im_max_long_edge(im_np_array, size=1080, return_pil_im=False,
                     resample_algo=Image.LANCZOS, debug=False):
    '''
    Return an image whose long edge is no longer than the given size

    The aspect ratio is preserved; images already within ``size`` are returned
    unchanged. Both 2-D (grayscale) and 3-D (H, W, C) arrays are accepted.

    Args:
        im_np_array: image as a numpy array, indexed (H, W[, C]).
        size: maximum length in pixels of the longer edge.
        return_pil_im: return a PIL image instead of a numpy array.
        resample_algo: default to LANCZOS b/c it gives best downscaling quality
            (per https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table)
        debug: print the input and output dimensions.
    '''
    org_h, org_w = im_np_array.shape[:2]
    if debug:
        print(f'im_max_long_edge: seeing input w,h of {(org_w, org_h)}')

    if max(org_h, org_w) <= size:
        out_im = im_np_array
        if debug:
            print(f'im_max_long_edge: image dim is smaller than max {size}. no resizing required.')
    else:
        # Scale the short edge with integer-ratio arithmetic and round it:
        # truncating a float ratio can lose a pixel to rounding error, and a
        # very thin image must not be shrunk to a 0-pixel edge.
        if org_h > org_w:  # fix h to size
            h = size
            w = max(1, round(size * org_w / org_h))
        else:  # fix w to size
            w = size
            h = max(1, round(size * org_h / org_w))
        pil_im = Image.fromarray(im_np_array).resize((w, h), resample=resample_algo)
        out_im = np.array(pil_im)
        if debug:
            print(f'im_max_long_edge: resizing image to w,h of {(w, h)}')

    return Image.fromarray(out_im) if return_pil_im else out_im


def format_depth_map(depth_map, debug=True):
    """Min-max normalise a depth map into a uint8 image in [0, 255].

    Args:
        depth_map: numpy array of depth values (any numeric dtype).
        debug: print the original min/max values.

    Returns:
        uint8 array of the same shape. A constant map (max == min) has no
        range to stretch and maps to all zeros instead of NaN garbage.
    """
    dmin = depth_map.min()
    dmax = depth_map.max()
    if debug:
        print(f'depth map origin min-max: ({dmin}, {dmax})')

    # Work in float64 so (max - min) cannot overflow/wrap for narrow int dtypes.
    depth = np.asarray(depth_map, dtype=np.float64)
    lo = float(dmin)
    value_range = float(dmax) - lo
    if value_range == 0:
        return np.zeros(depth.shape, dtype='uint8')
    normalized = (depth - lo) / value_range
    return (normalized * 255).astype('uint8')


def inference(img_array_rgb, model_obj, as_pil=False):
    """Predict a depth map for an image and return an 8-bit visualisation.

    The input is downscaled so its long edge is at most 720 px before being
    passed to the model. The model's gray image is min-max normalised to
    0-255 and resized back to the input's width and height.

    Args:
        img_array_rgb: image as a numpy array indexed (H, W, C); RGB per the
            name — TODO confirm the channel order the model expects.
        model_obj: object exposing ``predict`` and ``depth_map_to_grayimg``
            (e.g. the controller returned by ``get_model``).
        as_pil: return a PIL image instead of a numpy array.

    Returns:
        The 8-bit depth visualisation at the input's (H, W).
    """
    h, w = img_array_rgb.shape[:2]
    small_rgb = im_max_long_edge(img_array_rgb, return_pil_im=False, size=720)
    prediction = model_obj.predict(small_rgb, is_channels_first=False, normalize=True)
    gray_depth = model_obj.depth_map_to_grayimg(prediction)
    depth_u8 = format_depth_map(gray_depth)
    visual_depth_map = Image.fromarray(depth_u8).resize((w, h), resample=Image.LANCZOS)
    return visual_depth_map if as_pil else np.array(visual_depth_map)