# Gradio app (Hugging Face Space) for metric depth estimation with UniDepthV2.
import functools
import json
import os

import cv2
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from unidepth.models import UniDepthV2
# Load model configurations and initialize model
@functools.lru_cache(maxsize=1)
def load_model(config_path, model_path, encoder, device):
    """Build a UniDepthV2 model from a JSON config and load checkpoint weights.

    The most recent result is memoized. Repeated Gradio requests with the same
    arguments therefore reuse the loaded model instead of re-reading the config
    and checkpoint from disk on every call. A checkpoint overwritten in place at
    the same path is not picked up until different arguments evict the entry.

    Args:
        config_path: Path to the JSON config passed to ``UniDepthV2``.
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            ``'model'`` entry is a state dict matching the config exactly.
        encoder: Not used by this function; accepted for interface compatibility.
        device: Device the weights are mapped to and the model is moved to
            (the in-file caller passes ``'cuda'`` or ``'cpu'``).

    Returns:
        The model in eval mode on ``device``.

    Raises:
        Propagates errors from this function's own steps:
        - opening or parsing the config;
        - ``torch.load``;
        - ``KeyError`` if the checkpoint has no ``'model'`` entry;
        - ``load_state_dict(strict=True)`` on missing or unexpected keys.
    """
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    model = UniDepthV2(config)
    # SECURITY: depending on the installed torch version's ``weights_only``
    # default, torch.load may fully unpickle the file, which can execute
    # arbitrary code. model_path is user-editable in the UI, so only trusted
    # checkpoints should be reachable. Consider weights_only=True if the
    # checkpoint format allows it.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=True)
    return model.to(device).eval()
# Inference function
def depth_estimation(image, model_path, encoder='vits'):
    """Predict metric depth for ``image`` and return a colorized visualization.

    Args:
        image: Input picture as delivered by ``gr.Image(type="numpy")``.
            Assumed to be (H, W, 3) RGB; TODO confirm for RGBA or grayscale
            uploads.
        model_path: Path to the checkpoint file.
        encoder: Backbone name. It selects ``configs/config_v2_{encoder}14.json``,
            following the naming of the original ``config_v2_vits14.json``.

    Returns:
        A PIL RGB image: the depth map colored with the 'Spectral' colormap,
        with a labelled depth colorbar stacked underneath.

    Raises:
        gr.Error: with a user-facing message when the image, checkpoint or config
            is missing, or when loading or inference fails. Gradio displays this
            message in the UI. A plain string return cannot be shown by the
            ``gr.Image`` output.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if not model_path or not os.path.isfile(model_path):
        raise gr.Error("Model checkpoint not found. Please provide a valid model path.")
    # Previously the ViT-S config was hard-coded, so any other encoder's
    # checkpoint failed the strict state-dict load.
    config_path = f'configs/config_v2_{encoder}14.json'
    if not os.path.isfile(config_path):
        raise gr.Error(f"No model config found for encoder '{encoder}' ({config_path}).")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    try:
        model = load_model(config_path, model_path, encoder, device)
        rgb = torch.from_numpy(np.array(image)).permute(2, 0, 1).to(device)  # C, H, W
        with torch.no_grad():  # inference only; avoid building an autograd graph
            predictions = model.infer(rgb)
        depth = predictions["depth"].squeeze().cpu().numpy()

        cmap = matplotlib.colormaps.get_cmap('Spectral')
        depth_color, min_depth, max_depth = _colorize_depth(depth, cmap)
        colorbar = _render_colorbar(cmap, min_depth, max_depth, depth_color.shape[1])
        return _stack_with_colorbar(depth_color, colorbar)
    except Exception as e:
        # UI boundary: report any failure to the user through Gradio's error channel.
        raise gr.Error(f"Error occurred: {str(e)}") from e


def _colorize_depth(depth, cmap):
    """Map a 2-D depth array to an (H, W, 3) uint8 RGB image with ``cmap``.

    Returns ``(depth_color, min_depth, max_depth)``. A constant depth map is
    mapped to the colormap's low end instead of dividing by zero.
    """
    min_depth = float(depth.min())
    max_depth = float(depth.max())
    depth_range = max_depth - min_depth
    if depth_range > 0:
        depth_normalized = (depth - min_depth) / depth_range
    else:
        depth_normalized = np.zeros(depth.shape, dtype=np.float64)
    depth_color = (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
    return depth_color, min_depth, max_depth


def _render_colorbar(cmap, min_depth, max_depth, width):
    """Render a horizontal, labelled depth colorbar as an RGB image ``width`` px wide."""
    from io import BytesIO

    from matplotlib.figure import Figure

    # Use the object-oriented Figure API instead of pyplot. pyplot keeps global,
    # non-thread-safe state, and Gradio may serve requests concurrently. A bare
    # Figure is never registered with pyplot, so it needs no plt.close() and
    # cannot leak if savefig raises.
    fig = Figure(figsize=(6, 0.4))
    ax = fig.subplots()
    fig.subplots_adjust(bottom=0.5)
    norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
    mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
    fig.colorbar(mappable, cax=ax, orientation='horizontal', label='Depth (meters)')

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
    buf.seek(0)
    with Image.open(buf) as png:
        colorbar = png.convert('RGB')

    # Match the depth image's width, so labels are neither cropped on narrow
    # inputs nor left as a short strip on wide ones.
    if colorbar.width != width:
        height = max(1, round(colorbar.height * width / colorbar.width))
        colorbar = colorbar.resize((width, height))
    return colorbar


def _stack_with_colorbar(depth_color, colorbar):
    """Return a new white-backed RGB image with ``depth_color`` on top and ``colorbar`` below."""
    height, width = depth_color.shape[:2]
    canvas = Image.new('RGB', (width, height + colorbar.height), (255, 255, 255))
    canvas.paste(Image.fromarray(depth_color), (0, 0))
    canvas.paste(colorbar, (0, height))
    return canvas
# Gradio Interface
def main():
    """Build the Gradio interface around ``depth_estimation`` and launch it."""
    iface = gr.Interface(
        fn=depth_estimation,
        inputs=[
            gr.Image(type="numpy", label="Input Image"),
            gr.Textbox(value='checkpoint/latest.pth', label='Model Path'),
            gr.Dropdown(choices=['vits', 'vitb', 'vitl', 'vitg'], value='vits', label='Encoder'),
        ],
        outputs=[
            gr.Image(type="pil", label="Predicted Depth"),
        ],
        # The backend is UniDepthV2 (imported from unidepth.models), not
        # Depth Anything V2, so the UI text names the model actually used.
        title="UniDepth V2 Metric Depth Estimation",
        description="Upload an image to get its estimated metric depth map using UniDepth V2.",
    )
    iface.launch()


if __name__ == "__main__":
    main()