smhh24's picture
Upload 90 files
560b597 verified
raw
history blame
3.39 kB
import functools
import json
import logging
import os
from io import BytesIO

import cv2
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from unidepth.models import UniDepthV2
# Load model configurations and initialize model
def load_model(config_path, model_path, encoder, device):
    """Build a UniDepthV2 model from a JSON config and load checkpoint weights.

    Args:
        config_path: Path to the JSON configuration passed to ``UniDepthV2``.
        model_path: Path to a checkpoint file loaded with ``torch.load``; its
            ``'model'`` entry must hold the state dict.
        encoder: Accepted for interface compatibility but not used here; the
            architecture is determined entirely by ``config_path``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to (e.g. ``'cuda'`` or ``'cpu'``).

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
        RuntimeError: From ``load_state_dict`` when the weights do not match
            the configured architecture exactly (``strict=True``).
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (which differs e.g. on Windows).
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)
    model = UniDepthV2(config)

    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. The path reaching this function is user-editable in the
    # Gradio UI, so only trusted checkpoints should ever be loaded here.
    # Consider weights_only=True once the checkpoint is confirmed to contain
    # only tensors / plain containers.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=True)
    return model.to(device).eval()
# Inference function
# Config used when no encoder-specific config file is available.
_DEFAULT_CONFIG_PATH = 'configs/config_v2_vits14.json'
# Encoder names offered by the UI dropdown; anything else uses the default config.
_ENCODERS = ('vits', 'vitb', 'vitl', 'vitg')
# Nominal colorbar width in inches; the render DPI is derived from it.
_COLORBAR_WIDTH_IN = 6


def _config_path_for(encoder):
    """Return the config path for *encoder*, falling back to the ViT-S/14 config."""
    # NOTE(review): assumes per-encoder configs follow the same naming as the
    # vits14 file (configs/config_v2_<encoder>14.json); confirm they exist.
    # Unknown encoder strings are never turned into a path.
    if encoder in _ENCODERS:
        candidate = f'configs/config_v2_{encoder}14.json'
        if os.path.exists(candidate):
            return candidate
    return _DEFAULT_CONFIG_PATH


@functools.lru_cache(maxsize=1)
def _cached_model(config_path, model_path, encoder, device, mtime):
    """Load the model once and reuse it across requests.

    ``mtime`` is not used in the body. It is part of the cache key only, so
    that overwriting the checkpoint file forces a reload. ``maxsize=1`` keeps
    at most one model resident in (GPU) memory.
    """
    return load_model(config_path, model_path, encoder, device)


def _to_rgb_array(image):
    """Return *image* as a contiguous (H, W, 3) array.

    Grayscale (H, W) and single-channel (H, W, 1) inputs are replicated to
    three channels; a fourth (alpha) channel is dropped.
    """
    arr = np.asarray(image)
    if arr.ndim == 2:
        arr = arr[:, :, None]
    if arr.ndim != 3:
        raise ValueError(f"Expected an image of shape (H, W[, C]), got {arr.shape}")
    channels = arr.shape[2]
    if channels == 1:
        arr = np.repeat(arr, 3, axis=2)
    elif channels == 4:
        arr = arr[:, :, :3]
    elif channels != 3:
        raise ValueError(f"Unsupported number of image channels: {channels}")
    return np.ascontiguousarray(arr)


def _normalize_depth(depth, lo, hi):
    """Map *depth* linearly from [lo, hi] onto [0, 1].

    A constant map (hi == lo) would divide by zero and produce NaNs; it is
    mapped to all zeros instead.
    """
    span = hi - lo
    if span > 0:
        return (depth - lo) / span
    return np.zeros_like(depth)


def _render_colorbar(cmap, vmin, vmax, width):
    """Render a horizontal depth colorbar as an RGB PIL image exactly *width* px wide."""
    from matplotlib.figure import Figure

    # A bare Figure is not registered with pyplot's global figure manager, so
    # nothing leaks if savefig raises, and concurrent requests share no state.
    # Choose the DPI so the rendered bar is roughly as wide as the depth image
    # (keeps text crisp); the final resize below makes the width exact.
    dpi = max(width / _COLORBAR_WIDTH_IN, 50)
    fig = Figure(figsize=(_COLORBAR_WIDTH_IN, 0.4), dpi=dpi)
    ax = fig.subplots()
    fig.subplots_adjust(bottom=0.5)
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    mappable = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    fig.colorbar(mappable, cax=ax, orientation='horizontal', label='Depth (meters)')

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
    buf.seek(0)
    with Image.open(buf) as png:
        bar = png.convert('RGB')
    height = max(1, round(bar.height * width / bar.width))
    return bar.resize((width, height), Image.LANCZOS)


def _stack_with_colorbar(depth_color, colorbar):
    """Place *colorbar* directly below the (H, W, 3) uint8 image *depth_color*."""
    height, width = depth_color.shape[:2]
    canvas = Image.new('RGB', (width, height + colorbar.height), (255, 255, 255))
    canvas.paste(Image.fromarray(depth_color), (0, 0))
    canvas.paste(colorbar, (0, height))
    return canvas


def depth_estimation(image, model_path, encoder='vits'):
    """Predict depth for *image* and return a colour-mapped visualisation.

    Args:
        image: Input image as an array (from ``gr.Image(type="numpy")``);
            grayscale and RGBA inputs are converted to RGB.
        model_path: Path to the checkpoint file passed to ``load_model``.
        encoder: Encoder name from the UI. The matching
            ``configs/config_v2_<encoder>14.json`` is used when present;
            otherwise the ViT-S/14 config is used.

    Returns:
        A PIL RGB image: the depth map coloured with the 'Spectral' colormap,
        with a horizontal colorbar labelled 'Depth (meters)' below it. On
        failure a human-readable error string is returned instead.
    """
    try:
        if image is None:
            return "Please upload an image."
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if not os.path.exists(model_path):
            return "Model checkpoint not found. Please upload a valid model path."
        config_path = _config_path_for(encoder)
        model = _cached_model(config_path, model_path, encoder, device,
                              os.path.getmtime(model_path))

        rgb = torch.from_numpy(_to_rgb_array(image)).permute(2, 0, 1).to(device)  # C, H, W
        with torch.no_grad():
            predictions = model.infer(rgb)
        depth = predictions["depth"].squeeze().to('cpu').numpy()

        min_depth = float(depth.min())
        max_depth = float(depth.max())
        depth_normalized = _normalize_depth(depth, min_depth, max_depth)
        cmap = matplotlib.colormaps.get_cmap('Spectral')
        # Drop the colormap's alpha channel and scale to 8-bit RGB.
        depth_color = (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)

        colorbar = _render_colorbar(cmap, min_depth, max_depth, depth_color.shape[1])
        return _stack_with_colorbar(depth_color, colorbar)
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the server log and show
        # a short message to the user.
        logging.getLogger(__name__).exception("Depth estimation failed")
        return f"Error occurred: {str(e)}"
# Gradio Interface
def main():
    """Wire ``depth_estimation`` into a Gradio Interface and launch the server."""
    # Inputs are listed in the positional order of depth_estimation:
    # (image, model_path, encoder).
    image_input = gr.Image(type="numpy", label="Input Image")
    checkpoint_box = gr.Textbox(value='checkpoint/latest.pth', label='Model Path')
    encoder_menu = gr.Dropdown(
        choices=['vits', 'vitb', 'vitl', 'vitg'],
        value='vits',
        label='Encoder',
    )
    depth_output = gr.Image(type="pil", label="Predicted Depth")

    # NOTE(review): the title/description mention Depth Anything V2, while the
    # model actually loaded is UniDepthV2 -- confirm the wording is intended.
    demo = gr.Interface(
        fn=depth_estimation,
        inputs=[image_input, checkpoint_box, encoder_menu],
        outputs=[depth_output],
        title="Depth Anything V2 Metric Depth Estimation",
        description="Upload an image to get its estimated depth map using Depth Anything V2.",
    )
    demo.launch()
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    main()