# File: def1/scripts/deforum_helpers/depth_leres.py
# Uploaded by ddoc ("Upload 188 files"), commit 81f4d3a
# Copyright (C) 2023 Deforum LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Contact the authors: https://deforum.github.io/
import torch
import cv2
import os
import numpy as np
import torchvision.transforms as transforms
from .general_utils import download_file_with_checksum
from leres.lib.multi_depth_model_woauxi import RelDepthModel
from leres.lib.net_tools import load_ckpt
class LeReSDepth:
    """Relative depth estimator wrapping the LeReS ``RelDepthModel``.

    On construction the checkpoint is fetched into ``models_path`` via
    ``download_file_with_checksum``. The model is then built, put in eval mode,
    moved to CUDA when available (otherwise CPU), and its weights are loaded.
    """

    def __init__(self, width=448, height=448, models_path=None, checkpoint_name='res101.pth', backbone='resnext101'):
        """Download (if needed) and load the LeReS model.

        Args:
            width: Width the input image is resized to before inference.
            height: Height the input image is resized to before inference.
            models_path: Folder that holds or receives the checkpoint. Despite
                the ``None`` default it must be given: it is joined into the
                checkpoint path passed to ``load_ckpt``.
            checkpoint_name: File name of the checkpoint inside ``models_path``.
            backbone: Backbone name forwarded to ``RelDepthModel``.
        """
        self.width = width
        self.height = height
        self.models_path = models_path
        self.checkpoint_name = checkpoint_name
        self.backbone = backbone
        download_file_with_checksum(url='https://cloudstor.aarnet.edu.au/plus/s/lTIJF4vrvHCAI31/download', expected_checksum='7fdc870ae6568cb28d56700d0be8fc45541e09cea7c4f84f01ab47de434cfb7463cacae699ad19fe40ee921849f9760dedf5e0dec04a62db94e169cf203f55b1', dest_folder=models_path, dest_filename=self.checkpoint_name)
        self.depth_model = RelDepthModel(backbone=self.backbone)
        self.depth_model.eval()
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.depth_model.to(self.DEVICE)
        load_ckpt(os.path.join(self.models_path, self.checkpoint_name), self.depth_model, None, None)

    @staticmethod
    def scale_torch(img):
        """Convert an image array into a torch tensor for the network.

        A 3-D array with 3 channels in its last axis is converted with
        torchvision ``ToTensor`` and normalised with the standard ImageNet
        mean/std. Every other array is returned as a float32 tensor with its
        layout unchanged; a 2-D single-channel image first gains a leading axis.
        """
        # Decide RGB-ness from the *original* shape. After the leading axis is
        # added to a 2-D image, shape[2] is the image width, so a width-3
        # grayscale image would otherwise be misrouted into the RGB path.
        is_rgb = len(img.shape) == 3 and img.shape[2] == 3
        if len(img.shape) == 2:
            img = img[np.newaxis, :, :]
        if is_rgb:
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
            return transform(img)
        return torch.from_numpy(img.astype(np.float32))

    def predict(self, image):
        """Estimate depth for ``image`` and return it at the input's resolution.

        The image is resized to ``(self.width, self.height)``, run through the
        model, and the prediction is resized back to ``image``'s size.

        Returns:
            A tensor on ``self.DEVICE`` with a leading axis of size 1. It is
            (1, H, W) with H, W = ``image.shape[:2]``, assuming the model emits
            a single-channel map (TODO confirm).
        """
        resized_image = cv2.resize(image, (self.width, self.height))
        # Put the batch on the model's device instead of relying on
        # inference() to move a CPU tensor.
        img_torch = self.scale_torch(resized_image)[None, :, :, :].to(self.DEVICE)
        with torch.no_grad():  # inference only: build no autograd graph
            pred = self.depth_model.inference(img_torch)
        pred_depth = pred.cpu().numpy().squeeze()
        pred_depth_ori = cv2.resize(pred_depth, (image.shape[1], image.shape[0]))
        return torch.from_numpy(pred_depth_ori).unsqueeze(0).to(self.DEVICE)

    @staticmethod
    def _depth_to_uint16(depth):
        """Scale a depth map to uint16 so that its maximum maps to 60000.

        Accepts a NumPy array or a torch tensor. A leading singleton axis, as in
        the (1, H, W) output of ``predict``, is dropped so a single map becomes
        a 2-D image. If the maximum is not positive (e.g. an all-zero map), the
        result is all zeros rather than NaN garbage from dividing by zero.
        Values are clipped to [0, 60000] so negatives do not wrap around.
        """
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()
        depth = np.asarray(depth)
        if depth.ndim == 3 and depth.shape[0] == 1:
            depth = depth[0]
        max_val = depth.max()
        if not max_val > 0:  # also catches NaN
            return np.zeros(depth.shape, dtype=np.uint16)
        scaled = depth / max_val * 60000
        return np.clip(scaled, 0, 60000).astype(np.uint16)

    def save_raw_depth(self, depth, filepath):
        """Write ``depth`` to ``filepath`` as a 16-bit image whose max is 60000."""
        cv2.imwrite(filepath, self._depth_to_uint16(depth))

    def to(self, device):
        """Move the model to ``device`` and remember it for later inputs/outputs."""
        self.DEVICE = device
        self.depth_model = self.depth_model.to(device)

    def delete(self):
        """Drop the model reference and, when CUDA is available, release cached GPU memory.

        Safe to call more than once.
        """
        if hasattr(self, 'depth_model'):
            del self.depth_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()