import os
import torch
import spaces
import matplotlib
import numpy as np
import gradio as gr
from PIL import Image
from transformers import pipeline
from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider
from depth_anything_v2.dpt import DepthAnythingV2
from loguru import logger
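
# Gradio Space comparing Depth Anything V1 (via the transformers depth-estimation
# pipeline) and Depth Anything V2 (via the official DepthAnythingV2 implementation)
# on the same image, with the two colored depth maps shown in an ImageSlider.
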
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
"""
title = "# Depth Anything: Watch V1 and V2 side by side."
description1 = """Please refer to **Depth Anything V2** [paper](https://arxiv.org/abs/2406.09414) for more details."""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEFAULT_V2_MODEL_NAME = "Base"
DEFAULT_V1_MODEL_NAME = "Base"
cmap = matplotlib.colormaps.get_cmap('Spectral_r')
# --------------------------------------------------------------------
# Depth anything V1 configuration
# --------------------------------------------------------------------
depth_anything_v1_name2checkpoint = {
    "Small": "LiheYoung/depth-anything-small-hf",
    "Base": "LiheYoung/depth-anything-base-hf",
    "Large": "LiheYoung/depth-anything-large-hf",
}
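# Pipelines are created on first use and cached here, keyed by model size name.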
depth_anything_v1_pipelines = {}
# --------------------------------------------------------------------
# Depth anything V2 configuration
# --------------------------------------------------------------------
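# Model configuration for each ViT encoder size (feature and decoder channel widths).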
depth_anything_v2_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
depth_anything_v2_encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large',
    # 'vitg': 'Giant',  # we are undergoing company review procedures to release our giant model checkpoint
}
depth_anything_v2_name2encoder = {v: k for k, v in depth_anything_v2_encoder2name.items()}
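# Models are created on first use and cached here, keyed by model size name.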
depth_anything_v2_models = {}
# --------------------------------------------------------------------
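# Model loading and inference
# --------------------------------------------------------------------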
def get_v1_pipe(model_name):
    """Create a transformers depth-estimation pipeline for the selected V1 checkpoint."""
    return pipeline(task="depth-estimation", model=depth_anything_v1_name2checkpoint[model_name], device=DEVICE)

def get_v2_model(model_name):
    """Download the selected V2 checkpoint from the Hub and build the DepthAnythingV2 model."""
    encoder = depth_anything_v2_name2encoder[model_name]
    model = DepthAnythingV2(**depth_anything_v2_configs[encoder])
    filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
    state_dict = torch.load(filepath, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(DEVICE).eval()
    return model

def predict_depth_v1(image, model_name):
    """Run V1 depth estimation, creating and caching the pipeline on first use."""
    if model_name not in depth_anything_v1_pipelines:
        depth_anything_v1_pipelines[model_name] = get_v1_pipe(model_name)
    pipe = depth_anything_v1_pipelines[model_name]
    return pipe(image)

def predict_depth_v2(image, model_name):
    """Run V2 depth estimation, creating and caching the model on first use."""
    if model_name not in depth_anything_v2_models:
        depth_anything_v2_models[model_name] = get_v2_model(model_name)
    model = depth_anything_v2_models[model_name]
    # Move to CUDA when a GPU is available (on ZeroGPU the device is attached inside the @spaces.GPU call).
    if torch.cuda.is_available():
        model = model.cuda()
    return model.infer_image(image)

def compute_depth_map_v2(image, model_select: str):
    # DepthAnythingV2.infer_image expects a BGR (OpenCV-style) image, so flip the RGB channels.
    depth = predict_depth_v2(image[:, :, ::-1], model_select)
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    return colored_depth

def compute_depth_map_v1(image, model_select):
    # The transformers pipeline expects a PIL image and returns a dict with a "depth" PIL image.
    pil_image = Image.fromarray(image)
    depth = predict_depth_v1(pil_image, model_select)
    depth = np.array(depth["depth"]).astype(np.uint8)
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    return colored_depth

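# @spaces.GPU allocates a GPU for the duration of the call on ZeroGPU Spaces (no-op elsewhere).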
@spaces.GPU
@torch.no_grad()
def on_submit(image, model_v1_select, model_v2_select):
    logger.info(f"Computing depth for V1 model: {model_v1_select} and V2 model: {model_v2_select}")
    colored_depth_v1 = compute_depth_map_v1(image, model_v1_select)
    colored_depth_v2 = compute_depth_map_v2(image, model_v2_select)
    return colored_depth_v1, colored_depth_v2

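# --------------------------------------------------------------------
# Gradio UI
# --------------------------------------------------------------------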
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description1)
    gr.Markdown("### Depth Prediction demo")

    with gr.Row():
        model_select_v1 = gr.Dropdown(label="Depth Anything V1 Model", choices=list(depth_anything_v1_name2checkpoint.keys()), value=DEFAULT_V1_MODEL_NAME)
        model_select_v2 = gr.Dropdown(label="Depth Anything V2 Model", choices=list(depth_anything_v2_encoder2name.values()), value=DEFAULT_V2_MODEL_NAME)

    with gr.Row():
        gr.Markdown()
        gr.Markdown("Depth Maps: V1 <-> V2")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        # The slider displays the V1 and V2 depth maps as a comparable pair.
        depth_image_slider = ImageSlider(elem_id='img-display-output', position=0.5)

    submit = gr.Button(value="Compute Depth")
    submit.click(on_submit, inputs=[input_image, model_select_v1, model_select_v2], outputs=[depth_image_slider])

    example_files = os.listdir('assets/examples')
    example_files.sort()
    example_files = [os.path.join('assets/examples', filename) for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image])

if __name__ == '__main__':
demo.queue().launch(share=True)