# Spaces: Build error (page status banner captured with the source; not part of the program)
import BTS, cv2, torch, gdown, os, zipfile | |
import numpy as np | |
from PIL import Image | |
def download_model_weight(model_dir, file_key_dict=None):
    """Download model weight file(s) from Google Drive into ``model_dir``.

    Args:
        model_dir: Directory that receives the downloaded file(s). It is
            created when it does not exist. An empty string is treated as
            the current working directory.
        file_key_dict: Mapping of local file name -> Google Drive file id.
            When omitted, the single ``bts_latest`` entry is used. Every
            entry is downloaded to ``os.path.join(model_dir, <file name>)``.

    Raises:
        RuntimeError: If ``gdown.download`` returns ``None`` for an entry.
    """
    if file_key_dict is None:
        # Built per call so the default mapping is never shared between calls.
        file_key_dict = {'bts_latest': "1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"}

    # os.makedirs('') raises, so an empty directory means "right here".
    model_dir = model_dir or '.'
    if not os.path.isdir(model_dir):
        print(f'--- making model directory: {model_dir}')
        os.makedirs(model_dir, exist_ok=True)

    for fname, key in file_key_dict.items():
        url = f'https://drive.google.com/uc?id={key}&export=download'
        dest_fp = os.path.join(model_dir, fname)
        print(f'--- downloading model weights from {url}')
        # NOTE(review): relies on gdown.download returning the output path on
        # success and None on failure -- confirm for the pinned gdown version
        # (newer releases may raise instead, which is also fine here).
        result = gdown.download(url, dest_fp, quiet=True)
        if result is None:
            raise RuntimeError(f'failed to download model weights from {url}')
        # The file is kept exactly as downloaded; no archive extraction is done.
        print(f"--- downloaded model weights to {dest_fp}", flush=True)
def get_model(model_path='./models/bts_latest'):
    """Load the BTS model from ``model_path``, downloading weights if absent.

    Args:
        model_path: Path of the weight file handed to ``load_model``. When no
            file exists there, ``download_model_weight`` is called for the
            containing directory.

    Returns:
        The ``BTS.BtsController`` instance after ``load_model`` and ``eval``
        have been called on it.

    NOTE(review): the download helper is invoked with its default file
    mapping, so the downloaded file name only matches ``model_path`` when the
    base name is the default one (``bts_latest``) -- confirm for custom paths.
    """
    if not os.path.isfile(model_path):
        # A bare file name has an empty dirname; fall back to the current
        # directory so directory creation does not fail on ''.
        download_model_weight(model_dir=os.path.dirname(model_path) or '.')
    model = BTS.BtsController()
    model.load_model(model_path)
    model.eval()
    return model
def im_max_long_edge(im_np_array, size=1080, return_pil_im=False,
                     resample_algo=Image.LANCZOS, debug=False):
    """Return an image whose long edge is no longer than ``size``.

    Images that already fit are returned unchanged (no resampling). Larger
    images are scaled down, keeping the aspect ratio, so that the long edge
    equals ``size``.

    Args:
        im_np_array: Image as an array whose first two axes are
            (height, width); a trailing channel axis is optional.
        size: Maximum allowed length of the long edge, in pixels.
        return_pil_im: Return a ``PIL.Image`` instead of a numpy array.
        resample_algo: PIL resampling filter. Defaults to LANCZOS because it
            gives the best downscaling quality (per
            https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table).
        debug: Print the input size and the resize decision.

    Returns:
        The (possibly resized) image as a numpy array, or as a ``PIL.Image``
        when ``return_pil_im`` is true.
    """
    # Only height and width are needed, so 2-D arrays are accepted as well.
    org_h, org_w = im_np_array.shape[:2]
    if debug:
        print(f'im_max_long_edge: seeing input w,h of {(org_w, org_h)}')

    if max(org_h, org_w) <= size:
        if debug:
            print(f'im_max_long_edge: image dim is smaller than max {size}. no resizing required.')
        return Image.fromarray(im_np_array) if return_pil_im else im_np_array

    wh_ratio = org_w / org_h
    if org_h > org_w:
        # fix h to size
        h = size
        w = h * wh_ratio
    else:
        # fix w to size
        w = size
        h = w / wh_ratio
    # Truncate to whole pixels, but never down to 0: on extreme aspect
    # ratios the short edge would otherwise vanish and PIL rejects it.
    w = max(1, int(w))
    h = max(1, int(h))
    pil_im = Image.fromarray(im_np_array).resize((w, h), resample=resample_algo)
    if debug:
        print(f'im_max_long_edge: resizing image to w,h of {(w,h)}')
    return pil_im if return_pil_im else np.array(pil_im)
def format_depth_map(depth_map, debug=True):
    """Min-max normalize a depth map and scale it to 8-bit values.

    Args:
        depth_map: Array-like exposing ``min``, ``max``, ``shape`` and
            arithmetic operators (presumably a numpy array -- the result is
            converted with ``astype``).
        debug: Print the original value range before normalizing.

    Returns:
        A ``uint8`` array of the same shape where the smallest input value
        maps to 0 and the largest to 255. A constant-valued map has no range
        to stretch and yields all zeros.
    """
    dmin = depth_map.min()
    dmax = depth_map.max()
    if debug:
        print(f'depth map origin min-max: ({dmin}, {dmax})')

    span = dmax - dmin
    if span == 0:
        # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
        return np.zeros(depth_map.shape, dtype='uint8')

    # min-max normalization
    formatted = (depth_map - dmin) / span
    return (formatted * 255).astype('uint8')
def inference(img_array_rgb, model_obj, as_pil=False, max_long_edge=720):
    """Predict a depth map for an image and return it at the input size.

    The image is first shrunk so its long edge is at most ``max_long_edge``,
    passed to the model, and the prediction is converted to an 8-bit
    grayscale image that is resized back to the original width and height.

    Args:
        img_array_rgb: Input image array with shape (height, width, channels).
        model_obj: Object providing ``predict`` and ``depth_map_to_grayimg``
            (presumably the controller returned by ``get_model`` -- confirm).
        as_pil: Return a ``PIL.Image`` instead of a numpy array.
        max_long_edge: Long-edge limit applied before prediction. Defaults
            to 720.

    Returns:
        The depth visualization, sized (width, height) like the input, as a
        numpy array or as a ``PIL.Image`` when ``as_pil`` is true.
    """
    h, w, _ = img_array_rgb.shape
    resized_rgb = im_max_long_edge(img_array_rgb, return_pil_im=False, size=max_long_edge)
    prediction = model_obj.predict(resized_rgb, is_channels_first=False, normalize=True)
    visual_depth_map = model_obj.depth_map_to_grayimg(prediction)
    visual_depth_map = format_depth_map(visual_depth_map)
    # Scale the visualization back up to the caller's original resolution.
    visual_depth_map = Image.fromarray(visual_depth_map).resize((w, h), resample=Image.LANCZOS)
    return visual_depth_map if as_pil else np.array(visual_depth_map)