File size: 8,474 Bytes
81f4d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright (C) 2023 Deforum LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Contact the authors: https://deforum.github.io/

import gc
import cv2
import numpy as np
import torch
from PIL import Image
from einops import rearrange, repeat
from modules import devices
from modules.shared import cmd_opts
from .depth_adabins import AdaBinsModel
from .depth_leres import LeReSDepth
from .depth_midas import MidasDepth
from .depth_zoe import ZoeDepth
from .general_utils import debug_print

class DepthModel:
    """Singleton facade over the supported depth-estimation backends.

    Depending on ``depth_algorithm`` this lazily constructs one of
    ZoeDepth, LeReSDepth, AdaBinsModel or MidasDepth (plus an AdaBins
    helper for the legacy ``'+AdaBins (old)'`` blend modes) and exposes a
    uniform ``predict`` / ``to`` / ``to_image`` / ``save`` /
    ``delete_model`` interface to callers.
    """

    # Cached singleton instance. Replaced in __new__ when the algorithm
    # switches, when the resolution changes for a Zoe-family model, or when
    # the previous instance was flagged for deletion (should_delete).
    _instance = None

    def __new__(cls, *args, **kwargs):
        # args[0] is models_path and args[1] is the torch device (see the
        # _initialize() calls below); all other settings arrive via kwargs.
        keep_in_vram = kwargs.get('keep_in_vram', False)
        depth_algorithm = kwargs.get('depth_algorithm', 'Midas-3-Hybrid')
        Width, Height = kwargs.get('Width', 512), kwargs.get('Height', 512)
        midas_weight = kwargs.get('midas_weight', 0.2)
        model_switched = cls._instance and cls._instance.depth_algorithm != depth_algorithm
        resolution_changed = cls._instance and (cls._instance.Width != Width or cls._instance.Height != Height)
        zoe_algorithm = 'zoe' in depth_algorithm.lower()
        model_deleted = cls._instance and cls._instance.should_delete

        # Only Zoe is constructed with the target resolution (see
        # _initialize_model), so a resolution change alone triggers a
        # reload only for Zoe-family algorithms.
        should_reload = (cls._instance is None or model_deleted or model_switched or (zoe_algorithm and resolution_changed))

        if should_reload:
            cls._instance = super().__new__(cls)
            cls._instance._initialize(models_path=args[0], device=args[1], half_precision=not cmd_opts.no_half, keep_in_vram=keep_in_vram, depth_algorithm=depth_algorithm, Width=Width, Height=Height, midas_weight=midas_weight)
        elif cls._instance.should_delete and keep_in_vram:
            # NOTE(review): when should_delete is true, model_deleted above
            # already made should_reload true, so this branch looks
            # unreachable — confirm before relying on it.
            cls._instance._initialize(models_path=args[0], device=args[1], half_precision=not cmd_opts.no_half, keep_in_vram=keep_in_vram, depth_algorithm=depth_algorithm, Width=Width, Height=Height, midas_weight=midas_weight)
        # Mark the (possibly reused) instance for disposal on the next
        # construction unless the caller asked to keep it resident.
        cls._instance.should_delete = not keep_in_vram
        return cls._instance

    def _initialize(self, models_path, device, half_precision=not cmd_opts.no_half, keep_in_vram=False, depth_algorithm='Midas-3-Hybrid', Width=512, Height=512, midas_weight=1.0):
        """Store configuration on the instance and load the chosen backend.

        Note: the half_precision default is evaluated once at class
        definition time from cmd_opts.no_half; __new__ always passes it
        explicitly, so the default only matters for direct calls.
        """
        self.models_path = models_path
        self.device = device
        self.half_precision = half_precision
        self.keep_in_vram = keep_in_vram
        self.depth_algorithm = depth_algorithm
        self.Width, self.Height = Width, Height
        self.midas_weight = midas_weight
        # Running min/max of all depths seen so far, used by to_image for
        # normalization; seeded so the first frame always widens the range.
        self.depth_min, self.depth_max = 1000, -1000
        self.adabins_helper = None
        self._initialize_model()

    def _initialize_model(self):
        """Instantiate the backend model(s) matching depth_algorithm.

        Dispatch is by case-insensitive prefix/equality on the algorithm
        name; the legacy '(old)' variants additionally load AdaBins for
        blending in predict().

        Raises:
            Exception: if depth_algorithm matches no known backend.
        """
        depth_algo = self.depth_algorithm.lower()
        if depth_algo.startswith('zoe'):
            # Zoe is the only backend constructed with the target resolution.
            self.zoe_depth = ZoeDepth(self.Width, self.Height)
            if depth_algo == 'zoe+adabins (old)':
                self.adabins_model = AdaBinsModel(self.models_path, keep_in_vram=self.keep_in_vram)
                self.adabins_helper = self.adabins_model.adabins_helper
        elif depth_algo == 'leres':
            self.leres_depth = LeReSDepth(width=448, height=448, models_path=self.models_path, checkpoint_name='res101.pth', backbone='resnext101')
        elif depth_algo == 'adabins':
            self.adabins_model = AdaBinsModel(self.models_path, keep_in_vram=self.keep_in_vram)
            self.adabins_helper = self.adabins_model.adabins_helper
        elif depth_algo.startswith('midas'):
            self.midas_depth = MidasDepth(self.models_path, self.device, half_precision=self.half_precision, midas_model_type=self.depth_algorithm)
            if depth_algo == 'midas+adabins (old)':
                self.adabins_model = AdaBinsModel(self.models_path, keep_in_vram=self.keep_in_vram)
                self.adabins_helper = self.adabins_model.adabins_helper
        else:
            raise Exception(f"Unknown depth_algorithm: {self.depth_algorithm}")

    def predict(self, prev_img_cv2, midas_weight, half_precision) -> torch.Tensor:
        """Estimate a depth map for one frame using the active backend.

        Args:
            prev_img_cv2: HxWx3 image array castable to uint8 — presumably
                RGB given the RGB2BGR conversion below; TODO confirm the
                channel order each backend actually expects.
            midas_weight: blend weight for the legacy '+AdaBins (old)'
                modes; AdaBins is only consulted when it is < 1.0.
            half_precision: forwarded to the MiDaS backend only.

        Returns:
            torch.Tensor depth map. NOTE(review): the Zoe path moves the
            result to self.device, but the AdaBins path returns a CPU
            tensor — verify callers tolerate both.

        Raises:
            Exception: on an unknown algorithm, or when the pure-AdaBins
                backend fails to produce a depth map.
        """

        img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR))

        if self.depth_algorithm.lower().startswith('zoe'):
            depth_tensor = self.zoe_depth.predict(img_pil).to(self.device)
            if self.depth_algorithm.lower() == 'zoe+adabins (old)' and midas_weight < 1.0:
                use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
                if use_adabins: # if there was no error in getting the adabins depth, align midas with adabins
                    depth_tensor = self.blend_and_align_with_adabins(depth_tensor, adabins_depth, midas_weight)
        elif self.depth_algorithm.lower() == 'leres':
            # LeReS takes a float image normalized to [0, 1].
            depth_tensor = self.leres_depth.predict(prev_img_cv2.astype(np.float32) / 255.0)
        elif self.depth_algorithm.lower() == 'adabins':
            use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
            depth_tensor = torch.tensor(adabins_depth)
            if use_adabins is False:
                raise Exception("Error getting depth from AdaBins") # TODO: fallback to something else maybe?
        elif self.depth_algorithm.lower().startswith('midas'):
            depth_tensor = self.midas_depth.predict(prev_img_cv2, half_precision)
            if self.depth_algorithm.lower() == 'midas+adabins (old)' and midas_weight < 1.0:
                use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
                if use_adabins: # if there was no error in getting the adabins depth, align midas with adabins
                    depth_tensor = self.blend_and_align_with_adabins(depth_tensor, adabins_depth, midas_weight)
        else: # Unknown!
            raise Exception(f"Unknown depth_algorithm passed to depth.predict function: {self.depth_algorithm}")

        return depth_tensor

    def blend_and_align_with_adabins(self, depth_tensor, adabins_depth, midas_weight):
        """Rescale a MiDaS/Zoe depth map and blend it with AdaBins depth.

        The (50 - d) / 19 rescale aligns MiDaS output with AdaBins' range
        (original alignment constants from Disco Diffusion); the blend is a
        simple convex combination weighted by midas_weight. Returns the
        blended map as a tensor on self.device.
        """
        depth_tensor = torch.subtract(50.0, depth_tensor) / 19.0 # align midas depth with adabins depth. Original alignment code from Disco Diffusion
        blended_depth_map = (depth_tensor.cpu().numpy() * midas_weight + adabins_depth * (1.0 - midas_weight))
        depth_tensor = torch.from_numpy(np.expand_dims(blended_depth_map, axis=0)).squeeze().to(self.device)
        debug_print(f"Blended Midas Depth with AdaBins Depth")
        return depth_tensor
        
    def to(self, device):
        """Move the active backend (and AdaBins helper, if loaded) to device.

        Also records the device on self and frees cached CUDA memory.
        """
        self.device = device
        if self.depth_algorithm.lower().startswith('zoe'):
            self.zoe_depth.zoe.to(device)
        elif self.depth_algorithm.lower() == 'leres':
            self.leres_depth.to(device)
        elif self.depth_algorithm.lower().startswith('midas'):
            self.midas_depth.to(device)
        if hasattr(self, 'adabins_model'):
            self.adabins_model.to(device)
        gc.collect()
        torch.cuda.empty_cache()
        
    def to_image(self, depth: torch.Tensor):
        """Render a depth tensor as a greyscale RGB PIL image.

        Normalizes against the running min/max accumulated across all
        frames seen so far (self.depth_min/self.depth_max) — presumably to
        keep brightness consistent over an animation rather than per-frame.
        """
        depth = depth.cpu().numpy()
        # Ensure a leading channel axis so the rearrange below sees 'c h w'.
        depth = np.expand_dims(depth, axis=0) if len(depth.shape) == 2 else depth
        self.depth_min, self.depth_max = min(self.depth_min, depth.min()), max(self.depth_max, depth.max())
        # Guard against division by zero when the depth range is degenerate.
        denom = max(1e-8, self.depth_max - self.depth_min)
        temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c')
        return Image.fromarray(repeat(temp, 'h w 1 -> h w c', c=3).astype(np.uint8))

    def save(self, filename: str, depth: torch.Tensor):
        """Save a depth tensor to filename as an image (see to_image)."""
        self.to_image(depth).save(filename)

    def delete_model(self):
        """Release whichever backend models are loaded and free GPU memory.

        Zoe/LeReS expose an explicit delete(); MiDaS is simply dropped and
        left to garbage collection; AdaBins manages its own teardown via
        delete_model() (it tracks keep_in_vram internally — see
        AdaBinsModel; TODO confirm).
        """
        for attr in ['zoe_depth', 'leres_depth']:
            if hasattr(self, attr):
                getattr(self, attr).delete()
                delattr(self, attr)

        if hasattr(self, 'midas_depth'):
            del self.midas_depth

        if hasattr(self, 'adabins_model'):
            self.adabins_model.delete_model()

        gc.collect()
        torch.cuda.empty_cache()
        devices.torch_gc()