depthapi / depth.py
nev's picture
Initial commit
3fce28b
from typing import Optional, Union
from tqdm.auto import trange
from PIL import ImageOps
from PIL import Image
from torch import nn
import numpy as np
import torch
import cv2
class MidasDepth(nn.Module):
    """MiDaS monocular depth estimator loaded through ``torch.hub``.

    The wrapper loads an ``intel-isl/MiDaS`` model together with an input
    transform, keeps the network frozen in eval mode, and offers ``forward``
    (raw prediction as a tensor) and ``get_depth`` (rescaled NumPy map).
    """

    def __init__(self, model_type="DPT_Large",
                 device=torch.device(
                     "cuda" if torch.cuda.is_available() else "cpu"),
                 is_inpainting=False):
        """Load the model and its preprocessing transform.

        Args:
            model_type: Entry-point name passed to
                ``torch.hub.load("intel-isl/MiDaS", ...)``.
            device: ``torch.device`` or device string to run the model on.
                An ``mps`` device is replaced by the CPU.
            is_inpainting: Accepted for interface compatibility; not used by
                this class.
        """
        super().__init__()
        # torch.device() accepts both strings and existing device objects,
        # so callers may pass "cuda:0" as well as torch.device("cuda:0").
        self.device = torch.device(device)
        if self.device.type == "mps":
            self.device = torch.device("cpu")
        self.model = torch.hub.load(
            "intel-isl/MiDaS", model_type).to(self.device).eval().requires_grad_(False)
        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        # The MiDaS hub documentation pairs the small model with
        # ``small_transform``; every other model type keeps ``dpt_transform``.
        if model_type == "MiDaS_small":
            self.transform = midas_transforms.small_transform
        else:
            self.transform = midas_transforms.dpt_transform

    @torch.no_grad()
    def forward(self, image):
        """Run the model on one image and resize the result to the input size.

        Args:
            image: Tensor, NumPy array, or anything ``np.asarray`` accepts.
                Singleton dimensions are squeezed away and the result must be
                a 3-D height x width x channels array (presumably RGB in the
                0-255 range expected by the MiDaS transforms -- confirm
                against the transform in use).

        Returns:
            Tensor on ``self.device`` holding the raw model prediction,
            bicubically resized to the input height and width, with the
            batch dimension kept in front.

        Raises:
            ValueError: If the squeezed image is not 3-dimensional.
        """
        if torch.is_tensor(image):
            image = image.cpu().detach()
        if not isinstance(image, np.ndarray):
            image = np.asarray(image)
        image = image.squeeze()
        if image.ndim != 3:
            raise ValueError(
                "expected a single H x W x C image after squeezing, "
                f"got shape {image.shape}")
        batch = self.transform(image).to(self.device)
        prediction = self.model(batch)
        # interpolate() needs an explicit channel axis; add it, resize to the
        # input's (height, width), then drop the channel axis again.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.shape[-3:-1],
            mode="bicubic",
            align_corners=False,
        )[:, 0]
        return prediction

    @torch.no_grad()
    def get_depth(self, img):
        """Return a rescaled depth map for ``img`` as a NumPy array.

        The raw prediction is min-max normalised into ``[3, 10]`` and then
        inverted as ``30 / value``, so the output also lies in ``[3, 10]``:
        the largest raw prediction maps to 3 and the smallest to 10. A
        constant prediction yields a map filled with 10 instead of NaN.

        Args:
            img: PIL image (converted to RGB first), NumPy array, or CPU
                tensor laid out as height x width x channels.

        Returns:
            2-D ``numpy.ndarray`` with the height and width of ``img``.
        """
        # Duck-typed PIL check: grayscale, palette, RGBA or CMYK images would
        # otherwise reach the model with the wrong channel layout.
        if callable(getattr(img, "convert", None)):
            img = img.convert("RGB")
        # forward() does its preprocessing on the CPU in NumPy, so no device
        # transfer or 0-1 rescaling round trip is needed here.
        pixels = np.asarray(img, dtype=np.float32)
        raw = self(pixels)[0]
        lowest, highest = raw.min(), raw.max()
        # Guard against a zero range (constant prediction) to avoid 0 / 0.
        span = torch.clamp(highest - lowest, min=torch.finfo(raw.dtype).eps)
        low, high = 3, 10
        scaled = (raw - lowest) / span * (high - low) + low
        depth = 30 / scaled
        return depth.detach().cpu().numpy()
if __name__ == "__main__":
    # Demo: estimate and display the depth map of an image. The path may be
    # given as the first command-line argument; it defaults to "horse.jpg".
    import sys

    from matplotlib import pyplot as plt

    image_path = sys.argv[1] if len(sys.argv) > 1 else "horse.jpg"
    # Open through a context manager so the file handle is closed as soon as
    # get_depth has read the pixels.
    with Image.open(image_path) as picture:
        depth_map = MidasDepth().get_depth(picture)
    plt.imshow(depth_map)
    plt.show()