import BTS, cv2, torch, gdown, os, zipfile
import numpy as np
from PIL import Image
def download_model_weight(model_dir,
                          file_key_dict={'bts_latest': "1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"}):
    '''Download the model checkpoint from Google Drive into model_dir.'''
    if not os.path.isdir(model_dir):
        print(f'--- making model directory: {model_dir}')
        os.makedirs(model_dir)
    # single-entry dict: filename on disk -> Google Drive file id
    fname = list(file_key_dict.keys())[0]
    key = file_key_dict[fname]
    url = f'https://drive.google.com/uc?id={key}&export=download'
    tmp_zip_fp = os.path.join(model_dir, fname)
    print(f'--- downloading model weights from {url}')
    gdown.download(url, tmp_zip_fp, quiet=True)
    # zip extraction kept disabled in the original: the checkpoint is used as downloaded
    # with zipfile.ZipFile(tmp_zip_fp, "r") as zip_ref:
    #     for file in zip_ref.namelist():
    #         zip_ref.extract(file, model_dir)
    # os.remove(tmp_zip_fp)
    print(f"--- downloaded model weights to {tmp_zip_fp}", flush=True)
def get_model(model_path='./models/bts_latest'):
    if not os.path.isfile(model_path):
        download_model_weight(model_dir=os.path.dirname(model_path))
    model = BTS.BtsController()
    model.load_model(model_path)
    model.eval()
    return model
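# Usage sketch (an illustration, not from the original file): load the model
# once at startup and reuse it; the checkpoint is downloaded only on the
# first call.
#   model = get_model('./models/bts_latest')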
def im_max_long_edge(im_np_array, size=1080, return_pil_im=False,
                     resample_algo=Image.LANCZOS, debug=False):
    '''Return an image whose long edge is no longer than the given size.
    Args:
        resample_algo: default to LANCZOS b/c it gives best downscaling quality
            (per https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table)
    '''
    org_h, org_w, _ = im_np_array.shape
    out_im = None
    if debug:
        print(f'im_max_long_edge: seeing input w,h of {(org_w, org_h)}')
    if max(org_h, org_w) <= size:
        out_im = im_np_array
        if debug:
            print(f'im_max_long_edge: image dim is smaller than max {size}. no resizing required.')
    else:
        wh_ratio = org_w / org_h
        if org_h > org_w:
            # fix h to size
            h = size
            w = h * wh_ratio
        else:
            # fix w to size
            w = size
            h = w / wh_ratio
        w = int(w)
        h = int(h)
        pil_im = Image.fromarray(im_np_array).resize((w, h), resample=resample_algo)
        out_im = np.array(pil_im)
        if debug:
            print(f'im_max_long_edge: resizing image to w,h of {(w, h)}')
    return Image.fromarray(out_im) if return_pil_im else out_im
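# Worked example (illustration only): a 4000x3000 (w,h) input has long edge
# 4000 > 1080, so width is fixed to 1080 and height becomes
# int(1080 / (4000/3000)) = 810, preserving the aspect ratio.
#   small = im_max_long_edge(np.zeros((3000, 4000, 3), dtype=np.uint8))
#   # small.shape == (810, 1080, 3)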
def format_depth_map(depth_map, debug=True):
    '''Min-max normalize a depth map and scale it to uint8 [0, 255].'''
    dmax = depth_map.max()
    dmin = depth_map.min()
    if debug:
        print(f'depth map original min-max: ({dmin}, {dmax})')
    # formatted = ((depth_map / dmax) * 255).astype('uint8')
    # min-max normalization
    formatted = (depth_map - dmin) / (dmax - dmin)
    return (formatted * 255).astype('uint8')
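# Worked example: min-max normalization maps dmin -> 0 and dmax -> 255, e.g.
# a map spanning (0.5, 10.5) sends 5.5 to (5.5 - 0.5) / (10.5 - 0.5) * 255 = 127.
#   format_depth_map(np.array([[0.5, 5.5, 10.5]]))  # -> [[0, 127, 255]]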
def inference(img_array_rgb, model_obj, as_pil=False):
    '''Predict a depth map for an RGB image and return it at the original resolution.'''
    h, w, _ = img_array_rgb.shape
    # cap the long edge at 720 px to bound inference cost
    img_array_rgb = im_max_long_edge(img_array_rgb, return_pil_im=False, size=720)
    prediction = model_obj.predict(img_array_rgb, is_channels_first=False, normalize=True)
    visual_depth_map = model_obj.depth_map_to_grayimg(prediction)
    visual_depth_map = format_depth_map(visual_depth_map)
    # upscale the depth map back to the input's original w,h
    visual_depth_map = Image.fromarray(visual_depth_map).resize((w, h), resample=Image.LANCZOS)
    return visual_depth_map if as_pil else np.array(visual_depth_map)
    # earlier approach kept for reference: upsample the raw prediction with torch
    # prediction = torch.nn.functional.interpolate(
    #     prediction.unsqueeze(1),
    #     size=img_array_rgb.shape[:2],
    #     mode="bicubic",
    #     align_corners=False,
    # ).squeeze()
    #
    # output = prediction.cpu().numpy()
    # formatted = (output * 255 / np.max(output)).astype('uint8')
    # img = Image.fromarray(formatted)
    # return img
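# Minimal end-to-end sketch (an assumption, not from the original Space;
# 'input.jpg' and 'depth.png' are hypothetical paths). cv2 loads images as
# BGR, so convert to RGB before calling inference.
if __name__ == '__main__':
    model = get_model()
    bgr = cv2.imread('input.jpg')
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    depth = inference(rgb, model, as_pil=True)
    depth.save('depth.png')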