# Upload metadata scraped from the hosting page: author "ddoc", commit message "Upload 188 files", commit 81f4d3a
# Copyright (C) 2023 Deforum LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Contact the authors: https://deforum.github.io/
import gc
import cv2
import numpy as np
import torch
from PIL import Image
from einops import rearrange, repeat
from modules import devices
from modules.shared import cmd_opts
from .depth_adabins import AdaBinsModel
from .depth_leres import LeReSDepth
from .depth_midas import MidasDepth
from .depth_zoe import ZoeDepth
from .general_utils import debug_print
class DepthModel:
    """Singleton front-end over the supported depth-estimation backends.

    ``depth_algorithm`` is compared case-insensitively and selects the backend:
    - names starting with ``'zoe'`` use ZoeDepth;
    - ``'leres'`` uses LeReS;
    - ``'adabins'`` uses AdaBins;
    - names starting with ``'midas'`` use MiDaS;
    - the legacy ``'zoe+adabins (old)'`` / ``'midas+adabins (old)'`` variants
      additionally blend in AdaBins depth.

    ``DepthModel(models_path, device, **options)`` returns the shared instance
    and only rebuilds it when required (see ``__new__``).
    """

    # The shared instance, or None before the first successful construction.
    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, rebuilding it when required.

        Args:
            *args: ``(models_path, device)``. Either may instead be given as a
                keyword of the same name. They are only read when a rebuild
                happens.
            keep_in_vram: when False (the default), the instance is flagged
                ``should_delete`` so the next construction rebuilds it.
            depth_algorithm: backend name (default ``'Midas-3-Hybrid'``).
            Width, Height: resolution handed to the backend (default 512x512).
            midas_weight: stored on the instance (default 0.2).

        A rebuild happens when any of these hold:
        - no instance exists yet;
        - the current one is flagged ``should_delete``;
        - ``depth_algorithm`` differs from the current one;
        - a Zoe algorithm is requested at a different resolution.

        Raises:
            Exception: from ``_initialize`` for an unknown ``depth_algorithm``.
                The previously published instance (if any) is kept unchanged.
        """
        keep_in_vram = kwargs.get('keep_in_vram', False)
        depth_algorithm = kwargs.get('depth_algorithm', 'Midas-3-Hybrid')
        Width, Height = kwargs.get('Width', 512), kwargs.get('Height', 512)
        midas_weight = kwargs.get('midas_weight', 0.2)

        current = cls._instance
        if current is None:
            should_reload = True
        else:
            model_switched = current.depth_algorithm != depth_algorithm
            resolution_changed = current.Width != Width or current.Height != Height
            # ZoeDepth is the only backend constructed with Width/Height (see
            # _initialize_model), so only Zoe needs a rebuild on a resize.
            zoe_algorithm = 'zoe' in depth_algorithm.lower()
            should_reload = (current.should_delete or model_switched
                             or (zoe_algorithm and resolution_changed))

        if should_reload:
            models_path = args[0] if len(args) > 0 else kwargs['models_path']
            device = args[1] if len(args) > 1 else kwargs['device']
            instance = super().__new__(cls)
            instance._initialize(models_path=models_path, device=device,
                                 half_precision=not cmd_opts.no_half, keep_in_vram=keep_in_vram,
                                 depth_algorithm=depth_algorithm, Width=Width, Height=Height,
                                 midas_weight=midas_weight)
            # Publish only after a successful load. Otherwise a failed load would
            # leave a half-initialized singleton (missing `should_delete`) that
            # breaks every later construction.
            cls._instance = instance
        cls._instance.should_delete = not keep_in_vram
        return cls._instance

    def _initialize(self, models_path, device, half_precision=not cmd_opts.no_half, keep_in_vram=False,
                    depth_algorithm='Midas-3-Hybrid', Width=512, Height=512, midas_weight=1.0):
        """Store the configuration and load the backend(s) for ``depth_algorithm``.

        ``depth_min``/``depth_max`` start inverted (1000 / -1000). The first
        ``to_image`` call therefore narrows them to the observed depth range,
        assuming depths lie within [-1000, 1000].
        """
        self.models_path = models_path
        self.device = device
        self.half_precision = half_precision
        self.keep_in_vram = keep_in_vram
        self.depth_algorithm = depth_algorithm
        self.Width, self.Height = Width, Height
        self.midas_weight = midas_weight
        self.depth_min, self.depth_max = 1000, -1000
        self.adabins_helper = None
        self._initialize_model()

    def _load_adabins(self):
        """Load AdaBins and expose its helper as ``self.adabins_helper``."""
        self.adabins_model = AdaBinsModel(self.models_path, keep_in_vram=self.keep_in_vram)
        self.adabins_helper = self.adabins_model.adabins_helper

    def _initialize_model(self):
        """Instantiate the backend(s) selected by ``self.depth_algorithm``.

        Raises:
            Exception: if the algorithm name is not recognized.
        """
        depth_algo = self.depth_algorithm.lower()
        if depth_algo.startswith('zoe'):
            self.zoe_depth = ZoeDepth(self.Width, self.Height)
            if depth_algo == 'zoe+adabins (old)':
                self._load_adabins()
        elif depth_algo == 'leres':
            # LeReS runs at a fixed 448x448 regardless of Width/Height.
            self.leres_depth = LeReSDepth(width=448, height=448, models_path=self.models_path,
                                          checkpoint_name='res101.pth', backbone='resnext101')
        elif depth_algo == 'adabins':
            self._load_adabins()
        elif depth_algo.startswith('midas'):
            self.midas_depth = MidasDepth(self.models_path, self.device, half_precision=self.half_precision,
                                          midas_model_type=self.depth_algorithm)
            if depth_algo == 'midas+adabins (old)':
                self._load_adabins()
        else:
            raise Exception(f"Unknown depth_algorithm: {self.depth_algorithm}")

    def _blend_with_adabins_if_available(self, depth_tensor, img_pil, prev_img_cv2, midas_weight):
        """Blend ``depth_tensor`` with AdaBins depth, or return it unchanged if AdaBins failed."""
        use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
        if use_adabins:
            return self.blend_and_align_with_adabins(depth_tensor, adabins_depth, midas_weight)
        return depth_tensor

    def predict(self, prev_img_cv2, midas_weight, half_precision) -> torch.Tensor:
        """Estimate depth for one frame with the configured backend.

        Args:
            prev_img_cv2: image array. It is cast to uint8 and channel-swapped
                to build the PIL image used by Zoe and AdaBins. LeReS receives
                it scaled to float32 in [0, 1]; MiDaS receives it as-is.
            midas_weight: weight of the primary depth in the legacy
                ``'+adabins (old)'`` blends. Blending only happens when it is
                below 1.0.
            half_precision: forwarded to the MiDaS backend only.

        Returns:
            The depth tensor.

        Raises:
            Exception: if AdaBins-only prediction fails, or if the algorithm is
                unknown.
        """
        algo = self.depth_algorithm.lower()
        img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR))
        if algo.startswith('zoe'):
            depth_tensor = self.zoe_depth.predict(img_pil).to(self.device)
            if algo == 'zoe+adabins (old)' and midas_weight < 1.0:
                depth_tensor = self._blend_with_adabins_if_available(depth_tensor, img_pil, prev_img_cv2, midas_weight)
        elif algo == 'leres':
            depth_tensor = self.leres_depth.predict(prev_img_cv2.astype(np.float32) / 255.0)
        elif algo == 'adabins':
            use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
            # Check for failure *before* converting: a failed result must raise
            # this error rather than blow up inside torch.tensor().
            if use_adabins is False:
                raise Exception("Error getting depth from AdaBins")  # TODO: fallback to something else maybe?
            # NOTE(review): unlike the Zoe branch, this tensor is not moved to self.device here.
            depth_tensor = torch.tensor(adabins_depth)
        elif algo.startswith('midas'):
            depth_tensor = self.midas_depth.predict(prev_img_cv2, half_precision)
            if algo == 'midas+adabins (old)' and midas_weight < 1.0:
                depth_tensor = self._blend_with_adabins_if_available(depth_tensor, img_pil, prev_img_cv2, midas_weight)
        else:  # Unknown!
            raise Exception(f"Unknown depth_algorithm passed to depth.predict function: {self.depth_algorithm}")
        return depth_tensor

    def blend_and_align_with_adabins(self, depth_tensor, adabins_depth, midas_weight):
        """Rescale ``depth_tensor`` to the AdaBins scale and blend the two maps.

        The blended value is ``midas_weight * (50 - d) / 19 + (1 - midas_weight) * adabins``.
        It is returned as a tensor on ``self.device`` with all size-1 dimensions
        squeezed out. Assumes ``adabins_depth`` is a numpy array that broadcasts
        against ``depth_tensor`` — TODO confirm against AdaBinsModel.predict.
        """
        # Align the primary depth with AdaBins depth. Original alignment code from Disco Diffusion.
        aligned = torch.subtract(50.0, depth_tensor) / 19.0
        blended_depth_map = aligned.cpu().numpy() * midas_weight + adabins_depth * (1.0 - midas_weight)
        depth_tensor = torch.from_numpy(blended_depth_map).squeeze().to(self.device)
        debug_print("Blended Midas Depth with AdaBins Depth")
        return depth_tensor

    def to(self, device):
        """Move the loaded backend(s) to ``device`` and release cached CUDA memory."""
        self.device = device
        algo = self.depth_algorithm.lower()
        if algo.startswith('zoe'):
            self.zoe_depth.zoe.to(device)
        elif algo == 'leres':
            self.leres_depth.to(device)
        elif algo.startswith('midas'):
            self.midas_depth.to(device)
        # AdaBins may be loaded on its own or alongside Zoe/MiDaS in the "(old)" blends.
        if hasattr(self, 'adabins_model'):
            self.adabins_model.to(device)
        gc.collect()
        torch.cuda.empty_cache()

    def to_image(self, depth: torch.Tensor):
        """Render a single-channel depth map as a 3-channel uint8 PIL image.

        Normalization uses ``depth_min``/``depth_max``, which only ever widen
        across calls. Brightness is therefore consistent across the frames of a
        run. ``depth`` must be (H, W) or (1, H, W).
        """
        depth = depth.detach().cpu().numpy()  # detach: .numpy() refuses tensors that require grad
        depth = np.expand_dims(depth, axis=0) if len(depth.shape) == 2 else depth
        self.depth_min, self.depth_max = min(self.depth_min, depth.min()), max(self.depth_max, depth.max())
        denom = max(1e-8, self.depth_max - self.depth_min)  # guard against a constant map
        temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c')
        return Image.fromarray(repeat(temp, 'h w 1 -> h w c', c=3).astype(np.uint8))

    def save(self, filename: str, depth: torch.Tensor):
        """Write ``depth`` to ``filename`` as an image (see ``to_image``)."""
        self.to_image(depth).save(filename)

    def delete_model(self):
        """Release every loaded backend and free memory.

        The instance is flagged ``should_delete`` so that the next
        ``DepthModel(...)`` rebuilds it instead of returning this emptied
        instance. This holds even when it was created with ``keep_in_vram=True``.
        """
        for attr in ['zoe_depth', 'leres_depth']:
            if hasattr(self, attr):
                getattr(self, attr).delete()
                delattr(self, attr)
        if hasattr(self, 'midas_depth'):
            del self.midas_depth
        if hasattr(self, 'adabins_model'):
            self.adabins_model.delete_model()
            # Drop our references too, so to() and a repeated delete_model() skip it.
            del self.adabins_model
        self.adabins_helper = None
        self.should_delete = True
        gc.collect()
        torch.cuda.empty_cache()
        devices.torch_gc()