# app.py — uploaded by Sentdex to Hugging Face ("Upload app.py", commit 8286d35, 2.2 kB).
import cv2
import torch
import gradio as gr
import numpy as np
from PIL import Image
import time
# Choose the MiDaS variant once. The model and its input transform must agree:
# the large "MiDaS" model pairs with default_transform, "MiDaS_small" with
# small_transform. Loading only the selected model avoids fetching/building
# the large model when the small one was requested.
use_large_model = True
model_type = "MiDaS" if use_large_model else "MiDaS_small"
midas = torch.hub.load("intel-isl/MiDaS", model_type)

# Run on a GPU when one is present; CPU-only hosts fall back to "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
midas.to(device)
# Inference only: switch any BatchNorm/Dropout layers to evaluation mode so
# predictions do not depend on (or update statistics from) the current input.
midas.eval()

midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = (
    midas_transforms.default_transform
    if use_large_model
    else midas_transforms.small_transform
)
def _depth_to_uint8(depth_map):
    """Scale a float depth map into a uint8 grayscale array.

    The map is divided by its maximum so the largest value maps to 255.
    Negative values are clipped to 0 instead of being wrapped by the uint8
    cast. If the maximum is not a positive finite number, all zeros are
    returned.
    """
    peak = float(np.max(depth_map))
    if not np.isfinite(peak) or peak <= 0:
        # Nothing meaningful to scale against; avoid dividing by zero.
        return np.zeros(depth_map.shape, dtype=np.uint8)
    scaled = depth_map * 255 / peak
    # Bicubic resizing can undershoot below zero; casting a negative float
    # straight to uint8 is undefined and typically wraps to a bright value.
    return np.clip(scaled, 0, 255).astype(np.uint8)


def depth(img):
    """Estimate depth for *img* and return it next to the original image.

    Args:
        img: PIL image from the Gradio input component, or None when no
            image was submitted.

    Returns:
        An RGB PIL image twice as wide as *img*. The original image is on the
        left and the grayscale depth map is on the right; brighter pixels mean
        larger model outputs. Returns None if *img* is None.
    """
    if img is None:
        return None
    # The MiDaS hub example feeds RGB images. PIL images are already RGB, so
    # no BGR->RGB swap is applied (one would hand the model BGR). convert()
    # also normalizes RGBA / L / P inputs to three channels.
    rgb_image = img.convert("RGB")
    rgb = np.array(rgb_image)
    input_batch = transform(rgb).to(device)
    with torch.no_grad():
        prediction = midas(input_batch)
        # Upsample the prediction back to the input's (height, width).
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    depth_image = Image.fromarray(_depth_to_uint8(prediction.cpu().numpy()))

    # Compose the original (left) and the depth map (right) side by side.
    width, height = rgb_image.size
    combined = Image.new('RGB', (width * 2, height))
    combined.paste(rgb_image, (0, 0))
    combined.paste(depth_image, (width, 0))
    # NOTE: saving the result to disk (RGBDs/<timestamp>_RGBD.png) was
    # removed for hosting on Hugging Face.
    return combined
# Gradio >= 3 provides a single gr.Image component usable as input or output,
# and Gradio 4 removed the legacy gr.inputs / gr.outputs namespaces. Use the
# legacy classes only on older Gradio releases that lack gr.Image.
if hasattr(gr, "Image"):
    inputs = gr.Image(type="pil", label="Original Image")
    outputs = gr.Image(type="pil", label="Output Image")
else:
    inputs = gr.inputs.Image(type='pil', label="Original Image")
    outputs = gr.outputs.Image(type="pil", label="Output Image")
title = "RGB to RGBD for Looking Glass (using MiDaS)"
description = "Takes an RGB image and creates the depth + combines to the RGB image. Depth is predicted by MiDaS. This is a demo of the Looking Glass. For more information, visit https://lookingglassfactory.com"
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1907.01341v3'>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer</a> | <a href='https://github.com/intel-isl/MiDaS'>Github Repo</a></p>"
# Module-level `demo` lets tooling (e.g. `gradio app.py` reload mode) find the app.
demo = gr.Interface(depth, inputs, outputs, title=title, description=description, article=article)

if __name__ == "__main__":
    demo.launch()