File size: 8,474 Bytes
81f4d3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
# Copyright (C) 2023 Deforum LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Contact the authors: https://deforum.github.io/
import gc
import cv2
import numpy as np
import torch
from PIL import Image
from einops import rearrange, repeat
from modules import devices
from modules.shared import cmd_opts
from .depth_adabins import AdaBinsModel
from .depth_leres import LeReSDepth
from .depth_midas import MidasDepth
from .depth_zoe import ZoeDepth
from .general_utils import debug_print
class DepthModel:
    """Cached wrapper around one depth-estimation backend.

    Constructing ``DepthModel(models_path, device, **options)`` goes through
    ``__new__``, which hands back a single class-level instance and only
    rebuilds it when the requested configuration requires it. The class
    defines no ``__init__``; all setup happens in ``_initialize``.

    ``depth_algorithm`` is matched case-insensitively:
      * starts with ``zoe``   -> ZoeDepth (``zoe+adabins (old)`` also loads AdaBins)
      * ``leres``             -> LeReSDepth
      * ``adabins``           -> AdaBins only
      * starts with ``midas`` -> MidasDepth (``midas+adabins (old)`` also loads AdaBins)
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, building a new one when required.

        Positional arguments ``args[0]`` (models_path) and ``args[1]`` (device)
        are only read when a new instance has to be built.
        Keyword options: ``keep_in_vram`` (default False), ``depth_algorithm``
        (default 'Midas-3-Hybrid'), ``Width``/``Height`` (default 512) and
        ``midas_weight`` (default 0.2).

        A new instance is built when there is none yet, when the cached one is
        flagged ``should_delete``, when ``depth_algorithm`` differs, or when a
        Zoe algorithm is requested at a different resolution. When the cached
        instance is reused, only its ``should_delete`` flag is updated.
        """
        keep_in_vram = kwargs.get('keep_in_vram', False)
        depth_algorithm = kwargs.get('depth_algorithm', 'Midas-3-Hybrid')
        Width, Height = kwargs.get('Width', 512), kwargs.get('Height', 512)
        midas_weight = kwargs.get('midas_weight', 0.2)

        current = cls._instance
        should_reload = (
            current is None
            or current.should_delete
            or current.depth_algorithm != depth_algorithm
            # ZoeDepth is constructed with the frame size, so a resolution change needs a rebuild.
            or ('zoe' in depth_algorithm.lower() and (current.Width != Width or current.Height != Height))
        )
        # Note: the former "should_delete and keep_in_vram" re-initialisation path was
        # unreachable (should_delete alone already forces should_reload), so it is folded in here.
        if should_reload:
            instance = super().__new__(cls)
            instance._initialize(models_path=args[0], device=args[1], half_precision=not cmd_opts.no_half,
                                 keep_in_vram=keep_in_vram, depth_algorithm=depth_algorithm,
                                 Width=Width, Height=Height, midas_weight=midas_weight)
            # Publish only after a successful load, so a failed load cannot leave a
            # half-initialised object (lacking ``should_delete``) in the cache.
            cls._instance = instance
        cls._instance.should_delete = not keep_in_vram
        return cls._instance

    def _initialize(self, models_path, device, half_precision=not cmd_opts.no_half, keep_in_vram=False, depth_algorithm='Midas-3-Hybrid', Width=512, Height=512, midas_weight=1.0):
        """Store the configuration and load the backend(s) for ``depth_algorithm``."""
        self.models_path = models_path
        self.device = device
        self.half_precision = half_precision
        self.keep_in_vram = keep_in_vram
        self.depth_algorithm = depth_algorithm
        self.Width, self.Height = Width, Height
        self.midas_weight = midas_weight
        # Running normalisation range used by to_image(); it starts "inverted" so that
        # frame values inside (-1000, 1000) immediately replace these bounds.
        self.depth_min, self.depth_max = 1000, -1000
        self.adabins_helper = None
        self._initialize_model()

    def _load_adabins(self):
        """Load AdaBins (standalone or as a blending partner) and expose its helper."""
        self.adabins_model = AdaBinsModel(self.models_path, keep_in_vram=self.keep_in_vram)
        self.adabins_helper = self.adabins_model.adabins_helper

    def _initialize_model(self):
        """Instantiate the backend(s) selected by ``self.depth_algorithm``.

        Raises:
            Exception: if the algorithm name matches none of the supported families.
        """
        depth_algo = self.depth_algorithm.lower()
        if depth_algo.startswith('zoe'):
            self.zoe_depth = ZoeDepth(self.Width, self.Height)
            if depth_algo == 'zoe+adabins (old)':
                self._load_adabins()
        elif depth_algo == 'leres':
            self.leres_depth = LeReSDepth(width=448, height=448, models_path=self.models_path, checkpoint_name='res101.pth', backbone='resnext101')
        elif depth_algo == 'adabins':
            self._load_adabins()
        elif depth_algo.startswith('midas'):
            self.midas_depth = MidasDepth(self.models_path, self.device, half_precision=self.half_precision, midas_model_type=self.depth_algorithm)
            if depth_algo == 'midas+adabins (old)':
                self._load_adabins()
        else:
            raise Exception(f"Unknown depth_algorithm: {self.depth_algorithm}")

    def predict(self, prev_img_cv2, midas_weight, half_precision) -> torch.Tensor:
        """Estimate depth for one frame with the configured backend.

        Args:
            prev_img_cv2: frame as a numpy array. It is cast to uint8 with its first
                and third channels swapped for the PIL-based paths, and scaled to
                float32 / 255 for LeReS.
            midas_weight: weight of the Zoe/MiDaS depth when blending with AdaBins in
                the "(old)" variants; blending is only attempted when it is < 1.0.
            half_precision: forwarded to the MiDaS backend.

        Returns:
            The backend's depth tensor, blended with AdaBins for the "(old)"
            variants when AdaBins reports success.

        Raises:
            Exception: if AdaBins fails in 'adabins' mode, or the algorithm is unknown.
        """
        algo = self.depth_algorithm.lower()
        img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR))
        if algo.startswith('zoe'):
            depth_tensor = self.zoe_depth.predict(img_pil).to(self.device)
            if algo == 'zoe+adabins (old)' and midas_weight < 1.0:
                use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
                if use_adabins:  # if there was no error in getting the adabins depth, align with adabins
                    depth_tensor = self.blend_and_align_with_adabins(depth_tensor, adabins_depth, midas_weight)
        elif algo == 'leres':
            depth_tensor = self.leres_depth.predict(prev_img_cv2.astype(np.float32) / 255.0)
        elif algo == 'adabins':
            use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
            # Check for failure *before* converting: on failure the depth payload may not be
            # convertible, and torch.tensor() would raise first and mask this error.
            if use_adabins is False:
                raise Exception("Error getting depth from AdaBins")  # TODO: fallback to something else maybe?
            depth_tensor = torch.tensor(adabins_depth)
        elif algo.startswith('midas'):
            depth_tensor = self.midas_depth.predict(prev_img_cv2, half_precision)
            if algo == 'midas+adabins (old)' and midas_weight < 1.0:
                use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
                if use_adabins:  # if there was no error in getting the adabins depth, align midas with adabins
                    depth_tensor = self.blend_and_align_with_adabins(depth_tensor, adabins_depth, midas_weight)
        else:  # Unknown!
            raise Exception(f"Unknown depth_algorithm passed to depth.predict function: {self.depth_algorithm}")
        return depth_tensor

    def blend_and_align_with_adabins(self, depth_tensor, adabins_depth, midas_weight):
        """Rescale ``depth_tensor`` to AdaBins' range and blend the two linearly.

        The tensor is first mapped through ``(50 - d) / 19`` (alignment taken from
        Disco Diffusion), then combined as
        ``midas_weight * aligned + (1 - midas_weight) * adabins_depth``.
        The result is moved to ``self.device`` with all size-1 dims squeezed out.
        """
        depth_tensor = torch.subtract(50.0, depth_tensor) / 19.0  # align midas depth with adabins depth. Original alignment code from Disco Diffusion
        blended_depth_map = (depth_tensor.cpu().numpy() * midas_weight + adabins_depth * (1.0 - midas_weight))
        depth_tensor = torch.from_numpy(np.expand_dims(blended_depth_map, axis=0)).squeeze().to(self.device)
        debug_print("Blended Midas Depth with AdaBins Depth")
        return depth_tensor

    def to(self, device):
        """Move the loaded backend(s) to ``device`` and release cached GPU memory."""
        self.device = device
        algo = self.depth_algorithm.lower()
        if algo.startswith('zoe'):
            self.zoe_depth.zoe.to(device)
        elif algo == 'leres':
            self.leres_depth.to(device)
        elif algo.startswith('midas'):
            self.midas_depth.to(device)
        # AdaBins is moved whenever it was loaded, standalone or as a blending partner.
        if hasattr(self, 'adabins_model'):
            self.adabins_model.to(device)
        gc.collect()
        torch.cuda.empty_cache()

    def to_image(self, depth: torch.Tensor):
        """Render ``depth`` as an 8-bit, 3-channel grayscale PIL image.

        ``depth`` may be (H, W) or (1, H, W). Values are normalised with a
        min/max range accumulated across calls (it only ever widens), so
        successive frames share one scale.
        """
        depth = depth.cpu().numpy()
        depth = np.expand_dims(depth, axis=0) if len(depth.shape) == 2 else depth
        self.depth_min, self.depth_max = min(self.depth_min, depth.min()), max(self.depth_max, depth.max())
        denom = max(1e-8, self.depth_max - self.depth_min)  # avoid division by zero on a flat range
        temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c')
        return Image.fromarray(repeat(temp, 'h w 1 -> h w c', c=3).astype(np.uint8))

    def save(self, filename: str, depth: torch.Tensor):
        """Write the ``to_image`` rendering of ``depth`` to ``filename``."""
        self.to_image(depth).save(filename)

    def delete_model(self):
        """Release the loaded backend(s) and trigger garbage/GPU-cache collection."""
        for attr in ['zoe_depth', 'leres_depth']:
            if hasattr(self, attr):
                getattr(self, attr).delete()
                delattr(self, attr)
        if hasattr(self, 'midas_depth'):
            del self.midas_depth
        if hasattr(self, 'adabins_model'):
            self.adabins_model.delete_model()
        gc.collect()
        torch.cuda.empty_cache()
        devices.torch_gc()