"""Gradio demo for Metric3D v1/v2: metric depth, surface normals and point clouds from one image."""
import torch
import torch.nn.functional as F
import logging
import math
import os
import os.path as osp
# os.system('nvidia-smi')
import cupy
import sys

CODE_SPACE = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

try:
    from mmcv.utils import Config, DictAction
except ImportError:
    # Newer OpenMMLab stacks moved Config/DictAction from mmcv into mmengine.
    from mmengine import Config, DictAction

from mono.utils.logger import setup_logger
import glob
from mono.utils.comm import init_env
from mono.model.monodepth_model import get_configured_monodepth_model
from mono.utils.running import load_ckpt
from mono.utils.do_test import transform_test_data_scalecano, get_prediction
from mono.utils.custom_data import load_from_annos, load_data
from mono.utils.avg_meter import MetricAverageMeter
from mono.utils.visualization import save_val_imgs, create_html, save_raw_imgs, save_normal_val_imgs
import cv2
from tqdm import tqdm
import numpy as np
from PIL import Image, ExifTags
import matplotlib.pyplot as plt
from mono.utils.unproj_pcd import reconstruct_pcd, save_point_cloud, ply_to_obj
from mono.utils.transform import gray_to_colormap
from mono.utils.visualization import vis_surface_normal
import gradio as gr
import plotly.graph_objects as go


def _load_model(cfg_path, ckpt_path):
    """Build a monodepth model from ``cfg_path``, load ``ckpt_path`` into it and set eval mode.

    Returns:
        (cfg, model) tuple.
    """
    cfg = Config.fromfile(cfg_path)
    model = get_configured_monodepth_model(cfg, )
    model, _, _, _ = load_ckpt(ckpt_path, model, strict_match=False)
    model.eval()
    return cfg, model


cfg_large, model_large = _load_model(
    './mono/configs/HourglassDecoder/vit.raft5.large.py',
    './weight/metric_depth_vit_large_800k.pth',
)
cfg_small, model_small = _load_model(
    './mono/configs/HourglassDecoder/vit.raft5.small.py',
    './weight/metric_depth_vit_small_800k.pth',
)

device = "cuda"
model_large.to(device)
model_small.to(device)

# Root directory for per-session output files (depth/normal .npy, point cloud .ply).
_RECON_CACHE_ROOT = 'recon_cache'

# Multipliers from EXIF FocalPlaneResolutionUnit to millimetres.
# The EXIF spec defines 2 = inch and 3 = centimetre; 4/5 (mm/um) are common vendor extensions.
# Any other value falls back to inches, matching the previous behaviour.
_FOCAL_PLANE_UNIT_TO_MM = {2: 25.4, 3: 10.0, 4: 1.0, 5: 0.001}


def _get_save_dir(state_cache):
    """Return this session's output directory, creating it if needed.

    The path is memoised in ``state_cache['save_dir']`` so that every call within the same
    Gradio session writes into the same directory.
    """
    if 'save_dir' not in state_cache:
        # dtype=np.int64: the upper bound does not fit numpy's default int on platforms
        # where that default is 32-bit, which would raise ValueError.
        cache_id = int(np.random.randint(0, 100000000000, dtype=np.int64))
        while osp.exists(f'{_RECON_CACHE_ROOT}/{cache_id:08d}'):
            cache_id = int(np.random.randint(0, 100000000000, dtype=np.int64))
        state_cache['save_dir'] = f'{_RECON_CACHE_ROOT}/{cache_id:08d}'
    os.makedirs(state_cache['save_dir'], exist_ok=True)
    return state_cache['save_dir']


def _positive_float(value):
    """Convert an EXIF number (int, float or rational) to ``float``.

    Returns None when the value is missing, not convertible, non-finite or not > 0.
    """
    if value is None:
        return None
    try:
        value = float(value)
    except (TypeError, ValueError, ZeroDivisionError):
        # Rationals with a zero denominator may raise on conversion.
        return None
    return value if math.isfinite(value) and value > 0 else None


def predict_depth_normal(img, model_selection="vit-small", fx=1000.0, fy=1000.0, state_cache=None):
    """Predict metric depth and surface normals for one image.

    Args:
        img: Uploaded PIL image (the Gradio input uses ``type='pil'``), or None.
        model_selection: "vit-small" or "vit-large".
        fx, fy: Focal lengths in pixels; the principal point is taken as the image centre.
        state_cache: Per-session dict. Depth, normal, image, intrinsic and confidence are
            stored here for ``unprojection_pcd``.

    Returns:
        (depth colormap image, depth .npy path, normal visualisation image,
         normal .npy path, state_cache, status message).
        The first four entries are None on failure.
    """
    if state_cache is None:
        state_cache = {}

    if model_selection == "vit-small":
        model = model_small
        cfg = cfg_small
    elif model_selection == "vit-large":
        model = model_large
        cfg = cfg_large
    else:
        return None, None, None, None, state_cache, "Not implemented model."

    if img is None:
        return None, None, None, None, state_cache, "Please upload an image and wait for the upload to complete."
    if fx is None or fy is None or fx <= 0 or fy <= 0:
        # Cleared/invalid Number fields would otherwise fail deep inside preprocessing.
        return None, None, None, None, state_cache, "Please provide positive fx and fy (in pixels)."

    # Force 3 channels so grayscale / RGBA / palette images do not break cvtColor below.
    if isinstance(img, Image.Image):
        img = img.convert('RGB')
    cv_image = np.array(img)
    # NOTE(review): the PIL image is RGB, so COLOR_BGR2RGB (a plain channel swap) yields BGR.
    # unprojection_pcd unpacks pixel colours as (b, g, r), relying on this order; confirm what
    # transform_test_data_scalecano expects before changing it.
    img = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)

    intrinsic = [fx, fy, img.shape[1] / 2, img.shape[0] / 2]
    rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(
        img, intrinsic, cfg.data_basic)

    with torch.no_grad():
        pred_depth, pred_depth_scale, scale, output, confidence = get_prediction(
            model=model,
            input=rgb_input,
            cam_model=cam_models_stacks,
            pad_info=pad,
            scale_info=label_scale_factor,
            gt_depth=None,
            normalize_scale=cfg.data_basic.depth_range[1],
            ori_shape=[img.shape[0], img.shape[1]],
        )

    # Keep the first 3 channels of the normal head output, then crop away the padding
    # added by the preprocessing (pad = [top, bottom, left, right]).
    pred_normal = output['normal_out_list'][0][:, :3, :, :]
    H, W = pred_normal.shape[2:]
    pred_normal = pred_normal[:, :, pad[0]:H - pad[1], pad[2]:W - pad[3]]

    pred_depth = pred_depth.squeeze().cpu().numpy()
    pred_depth[pred_depth < 0] = 0
    pred_color = gray_to_colormap(pred_depth)

    # Resize normals back to the input resolution and move channels last for visualisation.
    pred_normal = F.interpolate(pred_normal, [img.shape[0], img.shape[1]], mode='bilinear').squeeze()
    pred_normal = pred_normal.permute(1, 2, 0)
    pred_color_normal = vis_surface_normal(pred_normal)
    pred_normal = pred_normal.cpu().numpy()

    # Store depth and normal maps in the session state for a later 3D reconstruction.
    state_cache['depth'] = pred_depth
    state_cache['normal'] = pred_normal
    state_cache['img'] = img
    state_cache['intrinsic'] = intrinsic
    state_cache['confidence'] = confidence

    # Save depth and normal maps as .npy files offered for download.
    save_dir = _get_save_dir(state_cache)
    depth_file = f"{save_dir}/depth.npy"
    normal_file = f"{save_dir}/normal.npy"
    np.save(depth_file, pred_depth)
    np.save(normal_file, pred_normal)

    img = Image.fromarray(pred_color)
    img_normal = Image.fromarray(pred_color_normal)
    return img, depth_file, img_normal, normal_file, state_cache, "Success!"


def get_camera(img):
    """Estimate sensor size (mm) and focal length (mm) from the image's EXIF data.

    Sensor size is derived as ``pixels / FocalPlane{X,Y}Resolution``, converted to mm via
    FocalPlaneResolutionUnit (inches when absent or unknown).

    Returns:
        (sensor_width, sensor_height, focal_length, message). Missing values are None.
        When only the sensor size is missing, a 16 mm-wide default with the image's
        aspect ratio is returned instead.
    """
    if img is None:
        return None, None, None, "Please upload an image and wait for the upload to complete."
    try:
        exif = img.getexif()
        exif.update(exif.get_ifd(ExifTags.IFD.Exif))
    except Exception:
        # Best effort: unreadable or missing EXIF falls through to the "not found" messages.
        exif = {}

    x_resolution = _positive_float(exif.get(ExifTags.Base.FocalPlaneXResolution, None))
    y_resolution = _positive_float(exif.get(ExifTags.Base.FocalPlaneYResolution, None))
    focal_length = _positive_float(exif.get(ExifTags.Base.FocalLength, None))
    unit_to_mm = _FOCAL_PLANE_UNIT_TO_MM.get(exif.get(ExifTags.Base.FocalPlaneResolutionUnit, 2), 25.4)

    # Convert sensor size to mm. FocalPlaneXResolution is pixels per unit along the image
    # width and Y along the height, see
    # https://photo.stackexchange.com/questions/40865/how-can-i-get-the-image-sensor-dimensions-in-mm-to-get-circle-of-confusion-from
    w, h = img.size
    sensor_width = w / x_resolution * unit_to_mm if x_resolution is not None else None
    sensor_height = h / y_resolution * unit_to_mm if y_resolution is not None else None

    message = "Success!"
    if focal_length is None:
        message = "Focal length not found in EXIF. Please manually input."
    elif sensor_width is None and sensor_height is None:
        sensor_width = 16
        sensor_height = h / w * sensor_width
        message = f"Sensor size not found in EXIF. Using {sensor_width}x{sensor_height:.2f} mm as default."
    return sensor_width, sensor_height, focal_length, message


def get_intrinsic(img, sensor_width, sensor_height, focal_length):
    """Convert the physical focal length (mm) to pixel focal lengths.

    Uses ``fx = W / sensor_width * f`` and ``fy = H / sensor_height * f``.

    Returns:
        (fx, fy, message). Falls back to (1000, 1000) when any input is missing or not positive.
    """
    if img is None:
        return None, None, "Please upload an image and wait for the upload to complete."
    if any(v is None or v <= 0 for v in (sensor_width, sensor_height, focal_length)):
        return 1000, 1000, "Insufficient information. Try detecting camera first or use default 1000 for fx and fy."

    w, h = img.size
    fx = w / sensor_width * focal_length
    fy = h / sensor_height * focal_length
    return fx, fy, "Success!"


def unprojection_pcd(state_cache):
    """Back-project the cached depth map into a coloured point cloud.

    Writes the full cloud to ``<save_dir>/output.ply``. It also builds a Plotly figure from
    at most 250k randomly sampled points.

    Returns:
        (plotly figure, .ply path, message). The first two entries are None when no
        prediction has been made yet.
    """
    if state_cache is None:
        state_cache = {}
    depth_map = state_cache.get('depth', None)
    img = state_cache.get('img', None)
    intrinsic = state_cache.get('intrinsic', None)
    if depth_map is None or img is None or intrinsic is None:
        # Three values: this handler is wired to three outputs (plot, file, message).
        return None, None, "Please predict depth and normal first."

    # TODO: optionally filter points by state_cache['confidence'] (resized to the depth map).

    # Downsample very tall maps; scaling the image scales every intrinsic by the same factor.
    max_height = 1080
    if depth_map.shape[0] > max_height:
        scale = max_height / depth_map.shape[0]
        depth_map = cv2.resize(depth_map, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        intrinsic = [value * scale for value in intrinsic]

    pcd_ply = f"{_get_save_dir(state_cache)}/output.ply"

    pcd = reconstruct_pcd(depth_map, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3])
    pcd_filtered = pcd.reshape(-1, 3)
    img_filtered = img.reshape(-1, 3)
    save_point_cloud(pcd_filtered, img_filtered, pcd_ply, binary=False)

    # Randomly subsample the point cloud so the browser-side plot stays responsive.
    num_samples = 250000
    if pcd_filtered.shape[0] > num_samples:
        indices = np.random.choice(pcd_filtered.shape[0], num_samples, replace=False)
        pcd_downsampled = pcd_filtered[indices]
        img_downsampled = img_filtered[indices]
    else:
        pcd_downsampled = pcd_filtered
        img_downsampled = img_filtered

    # Pixels are stored in BGR order (see predict_depth_normal), hence the (b, g, r) unpacking.
    color_str = np.array([f"rgb({r},{g},{b})" for b, g, r in img_downsampled])
    data = [go.Scatter3d(
        x=pcd_downsampled[:, 0],
        y=pcd_downsampled[:, 1],
        z=pcd_downsampled[:, 2],
        mode='markers',
        marker=dict(
            size=1,
            color=color_str,
            opacity=0.8,
        ),
    )]
    layout = go.Layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            camera=dict(
                eye=dict(x=0, y=0, z=-1),
                up=dict(x=0, y=-1, z=0),
            ),
            xaxis=dict(showgrid=False, showticklabels=False, visible=False),
            yaxis=dict(showgrid=False, showticklabels=False, visible=False),
            zaxis=dict(showgrid=False, showticklabels=False, visible=False),
        ),
    )
    fig = go.Figure(data=data, layout=layout)

    return fig, pcd_ply, "Success!"


title = "Metric3D"
description = '''# Metric3Dv2: A versatile monocular geometric foundation model for zero-shot metric depth and surface normal estimation
Gradio demo for Metric3D v1/v2 which takes in a single image for computing metric depth and surface normal. To use it, simply upload your image, or click one of the examples to load them.
Learn more from our paper linked below.'''
# NOTE(review): these labels presumably carried hyperlinks (paper / repo URLs) that are not
# visible here — TODO restore them.
article = "\n\nMetric3D arxiv | Metric3Dv2 arxiv | Github Repo\n\n"

# NOTE(review): Gradio renders elem_id verbatim as the HTML id, so elem_id="#button1" yields
# id="#button1" and this CSS never matches. Use elem_id="button1" if the rule is wanted
# (check the resulting layout first).
custom_css = '''#button1, #button2 {
    width: 20px;
}'''

examples = [
    # ["files/gundam.jpg"],
    "files/p50_pro.jpg",
    "files/iphone13.JPG",
    "files/canon_cat.JPG",
    "files/canon_dog.JPG",
    "files/museum.jpg",
    "files/terra.jpg",
    "files/underwater.jpg",
    "files/venue.jpg",
]

with gr.Blocks(title=title, css=custom_css) as demo:
    gr.Markdown(description + article)

    # input and control components
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type='pil', label="Original Image")
            _ = gr.Examples(examples=examples, inputs=[image_input])
        with gr.Column():
            model_dropdown = gr.Dropdown(["vit-small", "vit-large"], label="Model", value="vit-large")

            with gr.Accordion('Advanced options (beta)', open=True):
                with gr.Row():
                    sensor_width = gr.Number(None, label="Sensor Width in mm", precision=2)
                    sensor_height = gr.Number(None, label="Sensor Height in mm", precision=2)
                    focal_len = gr.Number(None, label="Focal Length in mm", precision=2)
                    camera_detector = gr.Button("Detect Camera from EXIF", elem_id="#button1")
                with gr.Row():
                    fx = gr.Number(1000.0, label="fx in pixels", precision=2)
                    fy = gr.Number(1000.0, label="fy in pixels", precision=2)
                    focal_detector = gr.Button("Calculate Intrinsic", elem_id="#button2")

            message_box = gr.Textbox(label="Messages")

    # depth and normal
    submit_button = gr.Button("Predict Depth and Normal")
    with gr.Row():
        with gr.Column():
            depth_output = gr.Image(label="Output Depth")
            depth_file = gr.File(label="Depth (.npy)")
        with gr.Column():
            normal_output = gr.Image(label="Output Normal")
            normal_file = gr.File(label="Normal (.npy)")

    # 3D reconstruction
    reconstruct_button = gr.Button("Reconstruct 3D")
    pcd_output = gr.Plot(label="3D Point Cloud (Sampled sparse version)")
    pcd_ply = gr.File(label="3D Point Cloud (.ply)")

    # cache for depth, normal maps and other states
    state_cache = gr.State({})

    # detect focal length in pixels
    camera_detector.click(fn=get_camera, inputs=[image_input],
                          outputs=[sensor_width, sensor_height, focal_len, message_box])
    focal_detector.click(fn=get_intrinsic, inputs=[image_input, sensor_width, sensor_height, focal_len],
                         outputs=[fx, fy, message_box])

    submit_button.click(fn=predict_depth_normal, inputs=[image_input, model_dropdown, fx, fy, state_cache],
                        outputs=[depth_output, depth_file, normal_output, normal_file, state_cache, message_box])
    reconstruct_button.click(fn=unprojection_pcd, inputs=[state_cache],
                             outputs=[pcd_output, pcd_ply, message_box])

if __name__ == "__main__":
    demo.launch()