|
from typing import Optional, Union |
|
from tqdm.auto import trange |
|
from PIL import ImageOps |
|
from PIL import Image |
|
from torch import nn |
|
import numpy as np |
|
import torch |
|
import cv2 |
|
|
|
|
|
class MidasDepth(nn.Module):
    """Monocular depth estimator wrapping a MiDaS model loaded via ``torch.hub``.

    Both the network and its input transform are fetched from the
    ``intel-isl/MiDaS`` hub repository when the object is constructed.
    """

    # Hub transform names for model types that must NOT use ``dpt_transform``.
    # Following the MiDaS hub usage example, ``MiDaS_small`` takes the small
    # (lower-resolution) transform and ``MiDaS`` takes the default transform.
    # Every other model type (e.g. ``DPT_Large``, ``DPT_Hybrid``) keeps
    # ``dpt_transform``.
    _TRANSFORM_BY_MODEL = {
        "MiDaS_small": "small_transform",
        "MiDaS": "default_transform",
    }

    def __init__(self, model_type="DPT_Large",
                 device=torch.device(
                     "cuda" if torch.cuda.is_available() else "cpu"),
                 is_inpainting=False):
        """Load the MiDaS network and the input transform that matches it.

        Args:
            model_type: Name of the MiDaS hub entry point to load
                (e.g. ``"DPT_Large"``, ``"DPT_Hybrid"``, ``"MiDaS_small"``).
            device: Device to run the model on. This is either a
                ``torch.device`` or anything ``torch.device()`` accepts, such
                as ``"cuda"``. MPS devices are redirected to the CPU.
            is_inpainting: Accepted for interface compatibility; this class
                does not use it.
        """
        super().__init__()
        self.device = torch.device(device)
        if self.device.type == "mps":
            # NOTE(review): presumably some MiDaS ops are unsupported on MPS,
            # hence the CPU fallback; confirm before removing.
            self.device = torch.device("cpu")

        model = torch.hub.load("intel-isl/MiDaS", model_type)
        # Inference only: eval mode and no parameter gradients.
        self.model = model.to(self.device).eval().requires_grad_(False)

        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        self.transform = getattr(midas_transforms,
                                 self._transform_name(model_type))

    @classmethod
    def _transform_name(cls, model_type):
        """Return the hub transform attribute name to use for *model_type*."""
        return cls._TRANSFORM_BY_MODEL.get(model_type, "dpt_transform")

    @torch.no_grad()
    def forward(self, image):
        """Run the model on one image and resize the raw prediction to it.

        Args:
            image: A single image laid out as ``H x W x C``, optionally with
                extra singleton dimensions (they are squeezed away). It may
                be a NumPy array, a torch tensor on any device, or anything
                ``np.asarray`` accepts (e.g. a PIL image). Grayscale
                (``H x W``) input is replicated to three channels, and a
                fourth (alpha) channel is dropped. ``get_depth`` passes pixel
                values in the 0-255 range.

        Returns:
            Raw model prediction as a ``(batch, H, W)`` tensor on
            ``self.device``, where ``(H, W)`` is the input resolution.
            Batch is presumably 1, since a single image goes through the
            hub transform; TODO confirm for every model type.
        """
        if torch.is_tensor(image):
            image = image.detach().cpu()
        image = np.asarray(image).squeeze()

        if image.ndim == 2:
            # Grayscale: the hub transforms expect three colour channels.
            image = np.repeat(image[..., None], 3, axis=-1)
        elif image.shape[-1] == 4:
            # RGBA: drop alpha; normalisation is defined for 3 channels only.
            image = np.ascontiguousarray(image[..., :3])

        batch = self.transform(image).to(self.device)
        prediction = self.model(batch)

        # Resize the network-resolution prediction back to the input (H, W).
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.shape[-3:-1],
            mode="bicubic",
            align_corners=False,
        )[:, 0]
        return prediction

    @torch.no_grad()
    def get_depth(self, img, near=3.0, far=10.0):
        """Return a depth map for *img* as a NumPy array in ``[near, far]``.

        The raw prediction is min-max normalised to ``t`` in ``[0, 1]``, and
        ``1 / depth`` is then interpolated linearly between ``1 / far``
        (``t == 0``) and ``1 / near`` (``t == 1``). The largest raw values
        therefore map to ``near``. The defaults reproduce the original
        mapping ``30 / (7 * t + 3)``.

        Args:
            img: Image accepted by ``forward`` (PIL image, array, or tensor).
            near: Depth assigned to the largest raw prediction (> 0).
            far: Depth assigned to the smallest raw prediction (> 0).

        Returns:
            ``(H, W)`` NumPy array of depths.

        Raises:
            ValueError: If ``near`` or ``far`` is not positive.
        """
        if near <= 0 or far <= 0:
            raise ValueError("near and far must be positive")

        prediction = self(img)[0]

        lo, hi = prediction.min(), prediction.max()
        span = hi - lo
        if span > 0:
            t = (prediction - lo) / span
        else:
            # A constant prediction carries no depth ordering; without this
            # guard the division would produce NaN everywhere.
            t = torch.zeros_like(prediction)

        depth = t * (far - near) + near
        depth = (near * far) / depth
        return depth.cpu().numpy()
|
|
|
|
|
if __name__ == "__main__":
    from matplotlib import pyplot as plt

    # Build the model before opening the image, matching the original
    # evaluation order.
    midas = MidasDepth()
    # The context manager closes the file handle promptly. convert("RGB")
    # forces a full load inside the block and guarantees 3-channel input
    # (e.g. for palette or RGBA images).
    with Image.open("horse.jpg") as img:
        depth = midas.get_depth(img.convert("RGB"))
    plt.imshow(depth)
    plt.show()
|
|