File size: 3,111 Bytes
81f4d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Copyright (C) 2023 Deforum LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Contact the authors: https://deforum.github.io/

import torch
import cv2
import os
import numpy as np
import torchvision.transforms as transforms
from .general_utils import download_file_with_checksum
from leres.lib.multi_depth_model_woauxi import RelDepthModel
from leres.lib.net_tools import load_ckpt
    
class LeReSDepth:
    """Monocular depth estimation wrapper around the LeReS ``RelDepthModel``.

    On construction this downloads the res101 checkpoint (checksum-verified)
    into ``models_path``, builds the model, and loads it onto CUDA when
    available, else CPU. ``predict`` runs single-image inference and returns
    a ``(1, H, W)`` float depth tensor on ``self.DEVICE``.
    """

    def __init__(self, width=448, height=448, models_path=None, checkpoint_name='res101.pth', backbone='resnext101'):
        # width/height: size the input image is resized to before inference.
        self.width = width
        self.height = height
        self.models_path = models_path
        self.checkpoint_name = checkpoint_name
        self.backbone = backbone

        # Fetch the pretrained checkpoint if not already present; the SHA-512
        # checksum guards against truncated/corrupted downloads.
        download_file_with_checksum(url='https://cloudstor.aarnet.edu.au/plus/s/lTIJF4vrvHCAI31/download', expected_checksum='7fdc870ae6568cb28d56700d0be8fc45541e09cea7c4f84f01ab47de434cfb7463cacae699ad19fe40ee921849f9760dedf5e0dec04a62db94e169cf203f55b1', dest_folder=models_path, dest_filename=self.checkpoint_name)

        self.depth_model = RelDepthModel(backbone=self.backbone)
        self.depth_model.eval()
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.depth_model.to(self.DEVICE)
        load_ckpt(os.path.join(self.models_path, self.checkpoint_name), self.depth_model, None, None)

    @staticmethod
    def scale_torch(img):
        """Convert an HxW or HxWx3 numpy image to a torch tensor.

        3-channel images are normalized with ImageNet mean/std via
        torchvision; other inputs are passed through as float32 tensors.
        NOTE(review): after the ``newaxis`` insertion a 2-D input has shape
        (1, H, W), so ``img.shape[2] == 3`` actually tests the *width* for
        that case — preserved as-is to match upstream LeReS behavior, but a
        width-3 grayscale image would be misrouted; confirm upstream intent.
        """
        if len(img.shape) == 2:
            img = img[np.newaxis, :, :]
        if img.shape[2] == 3:
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225))])
            img = transform(img)
        else:
            img = img.astype(np.float32)
            img = torch.from_numpy(img)
        return img

    def predict(self, image):
        """Run depth inference on a single HxWx3 image (numpy, BGR/RGB as
        provided by the caller) and return a (1, H, W) tensor on DEVICE,
        resized back to the input's original resolution."""
        resized_image = cv2.resize(image, (self.width, self.height))
        img_torch = self.scale_torch(resized_image)[None, :, :, :]
        # Move the input to the model's device explicitly rather than relying
        # on the model to transfer it; no_grad avoids autograd bookkeeping.
        with torch.no_grad():
            pred_depth = self.depth_model.inference(img_torch.to(self.DEVICE)).cpu().numpy().squeeze()
        pred_depth_ori = cv2.resize(pred_depth, (image.shape[1], image.shape[0]))
        return torch.from_numpy(pred_depth_ori).unsqueeze(0).to(self.DEVICE)

    def save_raw_depth(self, depth, filepath):
        """Write the depth map to ``filepath`` as a 16-bit image, scaled so
        the maximum depth maps to 60000.

        Accepts either a numpy array or the torch tensor returned by
        ``predict`` (the original code crashed on tensors: ``torch.Tensor``
        has no ``.astype``)."""
        if isinstance(depth, torch.Tensor):
            depth = depth.squeeze(0).detach().cpu().numpy()
        depth_normalized = (depth / depth.max() * 60000).astype(np.uint16)
        cv2.imwrite(filepath, depth_normalized)

    def to(self, device):
        """Move the model (and future outputs) to ``device``."""
        self.DEVICE = device
        self.depth_model = self.depth_model.to(device)

    def delete(self):
        """Drop the model reference so its memory can be reclaimed."""
        del self.depth_model