from typing import Optional, Union

from tqdm.auto import trange
from PIL import ImageOps
from PIL import Image
from torch import nn
import numpy as np
import torch
import cv2


class MidasDepth(nn.Module):
    """Frozen MiDaS model loaded from ``torch.hub`` plus its DPT transform.

    ``forward`` returns the raw model prediction resized to the input image;
    ``get_depth`` additionally rescales that prediction into a NumPy map
    whose values lie in ``[3, 10]``.
    """

    def __init__(self, model_type="DPT_Large",
                 device=torch.device(
                     "cuda" if torch.cuda.is_available() else "cpu"),
                 is_inpainting=False):
        """Load the MiDaS weights and pre-processing transform.

        Args:
            model_type: Entry-point name passed to
                ``torch.hub.load("intel-isl/MiDaS", ...)``.
            device: Device the model runs on. Either a ``torch.device`` or
                anything ``torch.device(...)`` accepts (e.g. ``"cpu"``).
            is_inpainting: Accepted for interface compatibility; not used by
                this class.
        """
        super().__init__()
        # Normalise so that string devices work with the ``.type`` check.
        self.device = torch.device(device)
        # An "mps" device is never used; the model runs on the CPU instead.
        if self.device.type == "mps":
            self.device = torch.device("cpu")
        self.model = torch.hub.load(
            "intel-isl/MiDaS", model_type
        ).to(self.device).eval().requires_grad_(False)
        self.transform = torch.hub.load(
            "intel-isl/MiDaS", "transforms").dpt_transform

    @torch.no_grad()
    def forward(self, image):
        """Run the model on one image and resize the result to the input.

        Args:
            image: Tensor, ndarray or array-like. Tensors are moved to the
                CPU and everything is converted to a NumPy array, from which
                singleton dimensions are squeezed out.
                NOTE(review): presumably a single H x W x 3 image in the
                0-255 range, as ``get_depth`` supplies - confirm against the
                MiDaS transform.

        Returns:
            The model prediction as a tensor on ``self.device``, bicubically
            resized to ``image.shape[-3:-1]`` (taken after squeezing), with
            the channel axis added for interpolation removed again.
        """
        if torch.is_tensor(image):
            image = image.cpu().detach()
        if not isinstance(image, np.ndarray):
            image = np.asarray(image)
        image = image.squeeze()

        batch = self.transform(image).to(self.device)
        prediction = self.model(batch)
        # interpolate() needs an explicit channel axis; add it for the resize
        # and strip it again with [:, 0].
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.shape[-3:-1],
            mode="bicubic",
            align_corners=False,
        )[:, 0]
        return prediction

    @torch.no_grad()
    def get_depth(self, img):
        """Return a rescaled depth map for ``img`` as a NumPy array.

        The raw prediction is min-max normalised, mapped linearly onto
        ``[3, 10]`` and then inverted as ``30 / value``. The result therefore
        also lies in ``[3, 10]``, with the largest raw prediction mapped to 3
        and the smallest to 10.

        Args:
            img: PIL image or array-like of pixel values. PIL images that
                are not in ``"RGB"`` mode are converted to RGB first.

        Returns:
            ``numpy.ndarray`` holding the first element of the model
            prediction after rescaling.
        """
        # Assumes the MiDaS transform expects three colour channels, so
        # palette / greyscale / RGBA inputs are converted - TODO confirm.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")

        # forward() works on host-side NumPy data, so hand it a float32 array
        # directly instead of round-tripping through the device.
        pixels = np.asarray(img, dtype=np.float32)
        depth = self(pixels)[0]

        lowest, highest = depth.min(), depth.max()
        # A constant prediction has zero range; clamp so the normalisation
        # yields a finite (flat) map rather than NaNs.
        span = (highest - lowest).clamp_min(torch.finfo(depth.dtype).eps)
        depth = (depth - lowest) / span * (10 - 3) + 3
        depth = 30 / depth
        return depth.detach().cpu().numpy()


if __name__ == "__main__":
    from matplotlib import pyplot as plt

    plt.imshow(MidasDepth().get_depth(Image.open("horse.jpg")))
    plt.show()