# standard library
from pathlib import Path
from typing import Optional, Union
import sys

# third party
import cv2
import numpy as np
import torch
from PIL import Image
from mmengine import Config

# metric 3d
metric3d_path = Path(__file__).resolve().parent
metric3d_mono_path = metric3d_path / 'mono'
sys.path.append(str(metric3d_path))
sys.path.append(str(metric3d_mono_path))
from mono.model.monodepth_model import get_configured_monodepth_model
from mono.utils.running import load_ckpt
from mono.utils.do_test import transform_test_data_scalecano, get_prediction
from mono.utils.mldb import load_data_info, reset_ckpt_path
# aliased so the module-level gray_to_colormap wrapper below does not shadow it
from mono.utils.transform import gray_to_colormap as _gray_to_colormap

__all__ = ['Metric3D']

def calculate_radius(mask):
    # size N of the (square) mask
    N = mask.shape[0]
    # center point of the mask
    center = (N - 1) / 2
    # coordinates of all points with value 0
    y, x = np.where(mask == 0)
    # distance of each such point from the center
    distances = np.sqrt((x - center) ** 2 + (y - center) ** 2)
    # the radius is the maximum of these distances
    radius = distances.max()
    return radius
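
# Example (hypothetical 5x5 mask): the only zero sits in a corner, 2 pixels
# from the center along each axis, so the radius is sqrt(2**2 + 2**2) ~= 2.83:
#   mask = np.ones((5, 5)); mask[0, 0] = 0
#   calculate_radius(mask)  # -> 2.828...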

class Metric3D:
    cfg_: Config
    model_: torch.nn.Module

    def __init__(
        self,
        checkpoint: Union[str, Path] = './weights/metric_depth_vit_large_800k.pth',
        model_name: str = 'v2-L',
    ) -> None:
        checkpoint = Path(checkpoint).resolve()
        cfg: Config = self._load_config_(model_name, checkpoint)
        # build model
        model = get_configured_monodepth_model(cfg)
        model = torch.nn.DataParallel(model).cuda()  # requires a CUDA device
        model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False)
        model.eval()
        # save to self
        self.cfg_ = cfg
        self.model_ = model

    def __call__(
        self,
        rgb_image: Union[np.ndarray, Image.Image, str, Path],
        intrinsic: Union[str, Path, np.ndarray],
        d_max: float = 300,
        d_min: float = 0,
        margin_mask: Optional[np.ndarray] = None,
        crop_margin: int = 0,
    ) -> np.ndarray:
        # read image
        if isinstance(rgb_image, (str, Path)):
            rgb_image = np.array(Image.open(rgb_image))
        elif isinstance(rgb_image, Image.Image):
            rgb_image = np.array(rgb_image)
        if isinstance(intrinsic, (str, Path)):
            intrinsic = np.loadtxt(intrinsic)
        intrinsic = intrinsic[:4]  # keep [fx, fy, cx, cy]
        # crop margin mask
        if crop_margin != 0:
            original_h, original_w = margin_mask.shape
            # radius = calculate_radius(margin_mask) - crop_margin
            radius = original_h // 2 - crop_margin
            left, right = int(original_w // 2 - radius), int(original_w // 2 + radius)
            up, bottom = int(original_h // 2 - radius), int(original_h // 2 + radius)
            rgb_image = rgb_image[up:bottom, left:right]
            h, w = rgb_image.shape[:2]
            # the principal point moves to the center of the crop
            intrinsic[2] = w / 2
            intrinsic[3] = h / 2
            cv2.imwrite("debug.png", rgb_image[:, :, ::-1])  # debug dump (RGB -> BGR)
        # compute the resize scale to the network input size
        h, w = rgb_image.shape[:2]
        input_size = (616, 1064)
        scale = min(input_size[0] / h, input_size[1] / w)
        # transform image
        rgb_input, cam_models_stacks, pad, label_scale_factor = \
            transform_test_data_scalecano(rgb_image, intrinsic, self.cfg_.data_basic)
        # predict depth
        normalize_scale = self.cfg_.data_basic.depth_range[1]
        rgb_input = rgb_input.unsqueeze(0)
        pred_depth, output = get_prediction(
            model=self.model_,
            input=rgb_input,
            cam_model=cam_models_stacks,
            pad_info=pad,
            scale_info=label_scale_factor,
            gt_depth=None,
            normalize_scale=normalize_scale,
            ori_shape=[h, w],
        )
        # post process
        # pred_depth = (pred_depth > 0) * (pred_depth < 300) * pred_depth
        pred_depth = pred_depth.squeeze().cpu().numpy()
        pred_depth[pred_depth > d_max] = 0
        pred_depth[pred_depth < d_min] = 0
        # strip the padding added by the transform: pad = [top, bottom, left, right]
        pred_depth = pred_depth[pad[0] : pred_depth.shape[0] - pad[1], pad[2] : pred_depth.shape[1] - pad[3]]
        canonical_to_real_scale = intrinsic[0] * scale / 1000.0  # 1000.0 is the focal length of the canonical camera
        pred_depth = pred_depth * canonical_to_real_scale  # now the depth is metric
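        # Worked example (hypothetical numbers): with fx = 500 and an input
        # resized by scale = 0.5, each canonical depth value is multiplied by
        # 500 * 0.5 / 1000.0 = 0.25 to recover metric depth.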
        # undo the margin crop applied at the beginning
        if margin_mask is not None:
            # original_h/w, up/bottom, left/right come from the crop_margin branch above
            final_depth = np.zeros((original_h, original_w))
            pred_depth = cv2.resize(pred_depth, (w, h))
            final_depth[up:bottom, left:right] = pred_depth
            return final_depth
        return pred_depth

    def _load_config_(
        self,
        model_name: str,
        checkpoint: Union[str, Path],
    ) -> Config:
        print(f'Loading model {model_name} from {checkpoint}')
        config_path = metric3d_path / 'mono/configs/HourglassDecoder'
        assert model_name in ['v2-L', 'v2-S', 'v2-g'], f"Model {model_name} not supported"
        # load config file
        cfg = Config.fromfile(
            str(config_path / 'vit.raft5.large.py') if model_name == 'v2-L'
            else str(config_path / 'vit.raft5.small.py') if model_name == 'v2-S'
            else str(config_path / 'vit.raft5.giant2.py')
        )
        cfg.load_from = str(checkpoint)
        # load data info
        data_info = {}
        load_data_info('data_info', data_info=data_info)
        cfg.mldb_info = data_info
        # update checkpoint info
        reset_ckpt_path(cfg.model, data_info)
        # set distributed
        cfg.distributed = False
        return cfg

def gray_to_colormap(depth: np.ndarray) -> np.ndarray:
    # delegate to the aliased import; reusing the imported name directly
    # would make this wrapper call itself recursively
    return _gray_to_colormap(depth)
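
# Minimal usage sketch: the image path, checkpoint path, and intrinsics below
# are hypothetical placeholders; intrinsics are [fx, fy, cx, cy] in pixels.
if __name__ == '__main__':
    m3d = Metric3D(checkpoint='./weights/metric_depth_vit_large_800k.pth')
    depth = m3d('example.png', intrinsic=np.array([1000.0, 1000.0, 320.0, 240.0]))
    # colorize and save (channel flip assumes the colormap is returned as RGB)
    cv2.imwrite('depth_vis.png', gray_to_colormap(depth)[:, :, ::-1])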