Spaces:
Running
on
A10G
Running
on
A10G
File size: 4,035 Bytes
bcec54e 311c16e ed2df5e 5cc1836 bcec54e 8414e4e bcec54e ee23f4b bcec54e 8f4661f bcec54e 8414e4e bcec54e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'stable-diffusion')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'taming-transformers')))
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth')))
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from depth.models_depth.model import EVPDepth
from depth.configs.train_options import TrainOptions
from depth.configs.test_options import TestOptions
import glob
import utils
import torchvision.transforms as transforms
from utils_depth.misc import colorize
from PIL import Image
import torch.nn.functional as F
import gradio as gr
import tempfile
# Stylesheet for the Gradio UI: caps the height of the image display
# components. The '#img-display-input' / '#img-display-output' selectors match
# the elem_id values set on the image components in create_demo.
# NOTE(review): no component visible in this file uses the
# 'img-display-container' id.
css = """
#img-display-container {
max-height: 50vh;
}
#img-display-input {
max-height: 40vh;
}
#img-display-output {
max-height: 40vh;
}
"""
def create_demo(model, device):
    """Build the depth-prediction UI and wire its Submit button to ``model``.

    Must be called inside an open ``gr.Blocks`` (or tab) context: the
    components created here attach to whatever layout is currently active.

    Args:
        model: depth network; called as ``model(image)`` and indexed with
            ``['pred_d']`` to obtain the predicted depth tensor.
        device: torch device the input tensor is moved to before inference.
    """
    gr.Markdown("### Depth Prediction demo")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
        depth_image = gr.Image(label="Depth Map", elem_id='img-display-output')
    raw_file = gr.File(label="16-bit raw depth, multiplier:256")
    submit = gr.Button("Submit")

    def on_submit(image):
        """Predict depth for one PIL image.

        Returns ``[colored_depth, png_path]``: the colorized depth map from
        ``colorize`` and the path of a 16-bit PNG storing depth * 256.
        Returns ``[None, None]`` (clearing both outputs) when no image was given.
        """
        if image is None:
            # Submit pressed with nothing uploaded: clear outputs instead of
            # crashing inside ToTensor.
            return [None, None]
        # ToTensor keeps every band, so RGBA / grayscale uploads would feed the
        # model 4 or 1 channels; force 3-channel RGB.
        image = image.convert('RGB')
        transform = transforms.ToTensor()
        image = transform(image).unsqueeze(0).to(device)
        shape = image.shape
        # Resize to 440x480, then pad 40 rows on top -> 480x480 network input.
        image = torch.nn.functional.interpolate(image, (440, 480), mode='bilinear', align_corners=True)
        image = F.pad(image, (0, 0, 40, 0))
        with torch.no_grad():
            pred = model(image)['pred_d']
        # Drop the 40 padded rows and resize back to the original input size.
        pred = pred[:, :, 40:, :]
        pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
        pred_d_numpy = pred.squeeze().cpu().numpy()
        colored_depth, _, _ = colorize(pred_d_numpy, cmap='gray_r')

        # Scale by 256 and saturate to the uint16 range; a bare astype() wraps
        # negative or too-large values around instead of clipping them.
        raw_values = np.clip(pred_d_numpy * 256, 0, np.iinfo(np.uint16).max).astype('uint16')
        # Only reserve a unique path here and close the handle immediately
        # (delete=False keeps the file), so no descriptor leaks per request and
        # PIL can reopen the path on every OS.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            raw_path = tmp.name
        Image.fromarray(raw_values).save(raw_path)
        return [colored_depth, raw_path]

    submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file])
    gr.Examples(examples=["imgs/test_img1.jpg", "imgs/test_img2.jpg", "imgs/test_img3.jpg", "imgs/test_img4.jpg"],
                inputs=[input_image])
def main():
    """Load the EVP NYU depth checkpoint and launch the Gradio demo.

    Options are parsed via ``TestOptions``; the checkpoint path is then forced to
    ``best_model_nyu.ckpt``, resolved against the current working directory
    (the module ``chdir``s into ``depth/`` at import time).
    """
    opt = TestOptions().initialize()
    args = opt.parse_args()
    args.ckpt_dir = 'best_model_nyu.ckpt'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EVPDepth(args=args, caption_aggregation=True)
    cudnn.benchmark = True
    model.to(device)

    model_weight = torch.load(args.ckpt_dir, map_location=device)['model']
    # State dicts saved from a DataParallel/DistributedDataParallel wrapper
    # prefix keys with 'module.'; strip it so keys match the bare model. A plain
    # dict comprehension avoids the never-imported OrderedDict and also copes
    # with an empty state dict (next(iter(...)) would raise StopIteration).
    prefix = 'module.'
    model_weight = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in model_weight.items()
    }
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # failed to load is visible instead of silently leaving random weights.
    load_result = model.load_state_dict(model_weight, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print('Warning: checkpoint/model key mismatch: {} missing, {} unexpected'.format(
            len(load_result.missing_keys), len(load_result.unexpected_keys)))
    model.eval()

    title = "# EVP"
    description = """Official demo for **EVP: Enhanced Visual Perception using Inverse Multi-Attentive Feature
Refinement and Regularized Image-Text Alignment**.
EVP is a deep learning model for metric depth estimation from a single image.
Please refer to our [paper](https://arxiv.org/abs/2312.08548) or [github](https://github.com/Lavreniuk/EVP) for more details."""
    # Pass the module-level stylesheet so the elem_id height limits apply.
    with gr.Blocks(css=css) as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Tab("Depth Prediction"):
            create_demo(model, device)
        gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/MykolaL/evp?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br>
<p><img src="https://visitor-badge.glitch.me/badge?page_id=MykolaL/evp" alt="visitors"></p></center>''')
    demo.queue().launch(share=True)
if __name__ == "__main__":
    # Script entry point: load the model and serve the Gradio demo.
    main()
|