File size: 3,525 Bytes
b369bda
 
b240372
b369bda
b240372
 
b369bda
 
 
b240372
 
b369bda
b240372
b369bda
 
 
 
b240372
 
 
 
 
b369bda
 
 
 
 
 
 
 
 
b240372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b369bda
b240372
 
 
 
b369bda
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import BTS, cv2, torch, gdown, os, zipfile
import numpy as np
from PIL import Image

def download_model_weight(model_dir, file_key_dict = None):
    '''Download model weights from Google Drive into model_dir.

    Args:
        model_dir: directory to save the weights in; created if it does not
            exist. An empty string means the current working directory.
        file_key_dict: mapping of {output file name: Google Drive file id}.
            Only the first entry is used. Defaults to the 'bts_latest' weights.
    Returns:
        the path the weights were written to.
    Raises:
        ValueError: if file_key_dict is empty.
    '''
    if file_key_dict is None:
        # Built per call instead of as a default argument so no mutable
        # object is shared between calls.
        file_key_dict = {'bts_latest': "1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"}
    if not file_key_dict:
        raise ValueError('file_key_dict must contain at least one {file name: drive id} entry')

    # os.makedirs('') raises, so only create a directory when one was given;
    # exist_ok guards against the directory appearing between check and create.
    if model_dir and not os.path.isdir(model_dir):
        print(f'--- making model directory: {model_dir}')
        os.makedirs(model_dir, exist_ok = True)

    fname, key = next(iter(file_key_dict.items()))
    url = f'https://drive.google.com/uc?id={key}&export=download'
    weight_fp = os.path.join(model_dir, fname)

    print(f'--- downloading model weights from {url}')
    gdown.download(url, weight_fp, quiet = True)
    print(f"--- downloaded model weights to {weight_fp}", flush=True)
    return weight_fp

def get_model(model_path = './models/bts_latest'):
    '''Load the BTS model, downloading its weights first if they are missing.

    Args:
        model_path: path to the weights file. If no file exists there, the
            default BTS weights are downloaded to exactly this path.
    Returns:
        a BTS.BtsController with the weights loaded and switched to eval mode.
    '''
    if not os.path.isfile(model_path):
        # The downloader names the saved file after the dict key, so pass
        # model_path's own basename; otherwise the file would always be saved
        # as 'bts_latest' and a custom model_path would never be found below.
        # `or '.'` handles a bare file name whose dirname is ''.
        download_model_weight(
            model_dir = os.path.dirname(model_path) or '.',
            file_key_dict = {os.path.basename(model_path): "1_mENn0G9YlLAAr3N8DVDt4Hk2SBbo1pl"})
    model = BTS.BtsController()
    model.load_model(model_path)
    model.eval()
    return model

def im_max_long_edge(im_np_array, size = 1080, return_pil_im = False,
        resample_algo = Image.LANCZOS, debug = False):
    ''' Return an image whose long edge is no longer than the given size

    The aspect ratio is preserved. An image already within the limit is
    returned as-is (no copy, no resampling).

    Args:
        im_np_array: image as an (H, W) or (H, W, C) numpy array
        size: maximum length, in pixels, of the longer edge
        return_pil_im: return a PIL Image instead of a numpy array
        resample_algo: default to LANCZOS b/c it gives best downscaling quality (per https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table)
        debug: print the input and output dimensions
    '''
    org_h, org_w = im_np_array.shape[:2]
    if debug:
        print(f'im_max_long_edge: seeing input w,h of {(org_w, org_h)}')

    if max(org_h, org_w) <= size:
        out_im = im_np_array
        if debug:
            print(f'im_max_long_edge: image dim is smaller than max {size}. no resizing required.')
    else:
        long_edge = int(size)
        # Derive the short edge from the original integer dims. Going through a
        # float w/h ratio and truncating with int() can drop a pixel (e.g.
        # 404.999... -> 404). Clamp so a very thin image never reaches 0 px,
        # which PIL rejects.
        if org_h > org_w:
            h = long_edge
            w = max(1, round(org_w * long_edge / org_h))
        else:
            w = long_edge
            h = max(1, round(org_h * long_edge / org_w))
        pil_im = Image.fromarray(im_np_array).resize((w, h), resample = resample_algo)
        out_im = np.array(pil_im)

        if debug:
            print(f'im_max_long_edge: resizing image to w,h of {(w,h)}')
    return Image.fromarray(out_im) if return_pil_im else out_im

def format_depth_map(depth_map, debug = True):
    '''Min-max normalize a depth map into a uint8 image spanning 0-255.

    Args:
        depth_map: numpy array of depth values (any shape)
        debug: print the original min/max values
    Returns:
        uint8 array of the same shape. The minimum maps to 0 and the maximum
        to 255; intermediate values are truncated, not rounded. A constant map
        (max == min) yields all zeros instead of the NaN garbage a 0/0
        division would produce.
    '''
    dmin = depth_map.min()
    dmax = depth_map.max()
    if debug:
        print(f'depth map origin min-max: ({dmin}, {dmax})')

    span = dmax - dmin
    if span == 0:
        return np.zeros(depth_map.shape, dtype = 'uint8')

    # min-max normalization to [0, 1], then scale to the uint8 range
    formatted = (depth_map - dmin) / span
    return (formatted * 255).astype('uint8')

def inference(img_array_rgb, model_obj, as_pil = False, max_long_edge = 720):
    '''Predict a depth map for an RGB image and return it as an 8-bit image.

    The image is downscaled so its long edge is at most max_long_edge before
    prediction. The normalized depth image is then resized (LANCZOS) back to
    the input's width and height.

    Args:
        img_array_rgb: RGB image as an (H, W, C) numpy array
        model_obj: object exposing predict() and depth_map_to_grayimg(),
            e.g. the BTS.BtsController returned by get_model
        as_pil: return a PIL Image instead of a numpy array
        max_long_edge: long-edge limit, in pixels, for the model input
    Returns:
        the min-max normalized depth image at the input's height and width
    '''
    h, w = img_array_rgb.shape[:2]
    model_input = im_max_long_edge(img_array_rgb, return_pil_im = False, size = max_long_edge)
    prediction = model_obj.predict(model_input, is_channels_first = False, normalize = True)
    depth_img = format_depth_map(model_obj.depth_map_to_grayimg(prediction))
    depth_img = Image.fromarray(depth_img).resize((w, h), resample = Image.LANCZOS)
    return depth_img if as_pil else np.array(depth_img)