# A reimplemented version in public environments by Xiao Fu and Mu Hu from typing import Any, Dict, Union import torch from torch.utils.data import DataLoader, TensorDataset import numpy as np from tqdm.auto import tqdm from PIL import Image from diffusers import ( DiffusionPipeline, DDIMScheduler, AutoencoderKL, ) from models.unet_2d_condition import UNet2DConditionModel from diffusers.utils import BaseOutput from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection import torchvision.transforms.functional as TF from torchvision.transforms import InterpolationMode from utils.image_util import resize_max_res,chw2hwc,colorize_depth_maps from utils.colormap import kitti_colormap from utils.depth_ensemble import ensemble_depths from utils.normal_ensemble import ensemble_normals from utils.batch_size import find_batch_size import cv2 class DepthNormalPipelineOutput(BaseOutput): """ Output class for Marigold monocular depth prediction pipeline. Args: depth_np (`np.ndarray`): Predicted depth map, with depth values in the range of [0, 1]. depth_colored (`PIL.Image.Image`): Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. normal_np (`np.ndarray`): Predicted normal map, with depth values in the range of [0, 1]. normal_colored (`PIL.Image.Image`): Colorized normal map, with the shape of [3, H, W] and values in [0, 1]. uncertainty (`None` or `np.ndarray`): Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. """ depth_np: np.ndarray depth_colored: Image.Image normal_np: np.ndarray normal_colored: Image.Image uncertainty: Union[None, np.ndarray] class DepthNormalEstimationPipeline(DiffusionPipeline): # two hyper-parameters latent_scale_factor = 0.18215 def __init__(self, unet:UNet2DConditionModel, vae:AutoencoderKL, scheduler:DDIMScheduler, image_encoder:CLIPVisionModelWithProjection, feature_extractor:CLIPImageProcessor, ): super().__init__() self.register_modules( unet=unet, vae=vae, scheduler=scheduler, image_encoder=image_encoder, feature_extractor=feature_extractor, ) self.img_embed = None @torch.no_grad() def __call__(self, input_image:Image, denoising_steps: int = 10, ensemble_size: int = 10, processing_res: int = 768, match_input_res:bool =True, batch_size:int = 0, domain: str = "indoor", color_map: str="Spectral", show_progress_bar:bool = True, ensemble_kwargs: Dict = None, ) -> DepthNormalPipelineOutput: # inherit from thea Diffusion Pipeline device = self.device input_size = input_image.size # adjust the input resolution. if not match_input_res: assert ( processing_res is not None )," Value Error: `resize_output_back` is only valid with " assert processing_res >=0 assert denoising_steps >=1 assert ensemble_size >=1 # --------------- Image Processing ------------------------ # Resize image if processing_res >0: input_image = resize_max_res( input_image, max_edge_resolution=processing_res ) # Convert the image to RGB, to 1. reomve the alpha channel. input_image = input_image.convert("RGB") image = np.array(input_image) # Normalize RGB Values. rgb = np.transpose(image,(2,0,1)) rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) rgb_norm = rgb_norm.to(device) assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # ----------------- predicting depth ----------------- duplicated_rgb = torch.stack([rgb_norm] * ensemble_size) single_rgb_dataset = TensorDataset(duplicated_rgb) # find the batch size if batch_size>0: _bs = batch_size else: _bs = 1 single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False) # predicted the depth depth_pred_ls = [] normal_pred_ls = [] if show_progress_bar: iterable_bar = tqdm( single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False ) else: iterable_bar = single_rgb_loader for batch in iterable_bar: (batched_image, )= batch # here the image is still around 0-1 depth_pred_raw, normal_pred_raw = self.single_infer( input_rgb=batched_image, num_inference_steps=denoising_steps, domain=domain, show_pbar=show_progress_bar, ) depth_pred_ls.append(depth_pred_raw.detach().clone()) normal_pred_ls.append(normal_pred_raw.detach().clone()) depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() #(10,224,768) normal_preds = torch.concat(normal_pred_ls, axis=0).squeeze() torch.cuda.empty_cache() # clear vram cache for ensembling # ----------------- Test-time ensembling ----------------- if ensemble_size > 1: depth_pred, pred_uncert = ensemble_depths( depth_preds, **(ensemble_kwargs or {}) ) normal_pred = ensemble_normals(normal_preds) else: depth_pred = depth_preds normal_pred = normal_preds pred_uncert = None # ----------------- Post processing ----------------- # Scale prediction to [0, 1] min_d = torch.min(depth_pred) max_d = torch.max(depth_pred) depth_pred = (depth_pred - min_d) / (max_d - min_d) # Convert to numpy depth_pred = depth_pred.cpu().numpy().astype(np.float32) normal_pred = normal_pred.cpu().numpy().astype(np.float32) # Resize back to original resolution if match_input_res: pred_img = Image.fromarray(depth_pred) pred_img = pred_img.resize(input_size) depth_pred = np.asarray(pred_img) normal_pred = cv2.resize(chw2hwc(normal_pred), input_size, interpolation = cv2.INTER_NEAREST) # Clip output range: current size is the original size depth_pred = depth_pred.clip(0, 1) normal_pred = normal_pred.clip(-1, 1) # Colorize depth_colored = colorize_depth_maps( depth_pred, 0, 1, cmap=color_map ).squeeze() # [3, H, W], value in (0, 1) depth_colored = (depth_colored * 255).astype(np.uint8) depth_colored_hwc = chw2hwc(depth_colored) depth_colored_img = Image.fromarray(depth_colored_hwc) normal_colored = ((normal_pred + 1)/2 * 255).astype(np.uint8) normal_colored_img = Image.fromarray(normal_colored) self.img_embed = None return DepthNormalPipelineOutput( depth_np = depth_pred, depth_colored = depth_colored_img, normal_np = normal_pred, normal_colored = normal_colored_img, uncertainty=pred_uncert, ) def __encode_img_embed(self, rgb): """ Encode clip embeddings for img """ clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:,None,None].to(device=self.device, dtype=self.dtype) clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:,None,None].to(device=self.device, dtype=self.dtype) img_in_proc = TF.resize((rgb +1)/2, (self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']), interpolation=InterpolationMode.BICUBIC, antialias=True ) # do the normalization in float32 to preserve precision img_in_proc = ((img_in_proc.float() - clip_image_mean) / clip_image_std).to(self.dtype) img_embed = self.image_encoder(img_in_proc).image_embeds.unsqueeze(1).to(self.dtype) self.img_embed = img_embed @torch.no_grad() def single_infer(self,input_rgb:torch.Tensor, num_inference_steps:int, domain:str, show_pbar:bool,): device = input_rgb.device # Set timesteps: inherit from the diffuison pipeline self.scheduler.set_timesteps(num_inference_steps, device=device) # here the numbers of the steps is only 10. timesteps = self.scheduler.timesteps # [T] # encode image rgb_latent = self.encode_RGB(input_rgb) # Initial geometric maps (Guassian noise) geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1) rgb_latent = rgb_latent.repeat(2,1,1,1) # Batched img embedding if self.img_embed is None: self.__encode_img_embed(input_rgb) batch_img_embed = self.img_embed.repeat( (rgb_latent.shape[0], 1, 1) ) # [B, 1, 768] # hybrid switcher geo_class = torch.tensor([[0., 1.], [1, 0]], device=device, dtype=self.dtype) geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1) if domain == "indoor": domain_class = torch.tensor([[1., 0., 0]], device=device, dtype=self.dtype).repeat(2,1) elif domain == "outdoor": domain_class = torch.tensor([[0., 1., 0]], device=device, dtype=self.dtype).repeat(2,1) elif domain == "object": domain_class = torch.tensor([[0., 0., 1]], device=device, dtype=self.dtype).repeat(2,1) domain_embedding = torch.cat([torch.sin(domain_class), torch.cos(domain_class)], dim=-1) class_embedding = torch.cat((geo_embedding, domain_embedding), dim=-1) # Denoising loop if show_pbar: iterable = tqdm( enumerate(timesteps), total=len(timesteps), leave=False, desc=" " * 4 + "Diffusion denoising", ) else: iterable = enumerate(timesteps) for i, t in iterable: unet_input = torch.cat([rgb_latent, geo_latent], dim=1) # predict the noise residual noise_pred = self.unet( unet_input, t.repeat(2), encoder_hidden_states=batch_img_embed, class_labels=class_embedding ).sample # [B, 4, h, w] # compute the previous noisy sample x_t -> x_t-1 geo_latent = self.scheduler.step(noise_pred, t, geo_latent).prev_sample geo_latent = geo_latent torch.cuda.empty_cache() depth = self.decode_depth(geo_latent[0][None]) depth = torch.clip(depth, -1.0, 1.0) depth = (depth + 1.0) / 2.0 normal = self.decode_normal(geo_latent[1][None]) normal /= (torch.norm(normal, p=2, dim=1, keepdim=True)+1e-5) normal *= -1. return depth, normal def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor: """ Encode RGB image into latent. Args: rgb_in (`torch.Tensor`): Input RGB image to be encoded. Returns: `torch.Tensor`: Image latent. """ # encode h = self.vae.encoder(rgb_in) moments = self.vae.quant_conv(h) mean, logvar = torch.chunk(moments, 2, dim=1) # scale latent rgb_latent = mean * self.latent_scale_factor return rgb_latent def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: """ Decode depth latent into depth map. Args: depth_latent (`torch.Tensor`): Depth latent to be decoded. Returns: `torch.Tensor`: Decoded depth map. """ # scale latent depth_latent = depth_latent / self.latent_scale_factor # decode z = self.vae.post_quant_conv(depth_latent) stacked = self.vae.decoder(z) # mean of output channels depth_mean = stacked.mean(dim=1, keepdim=True) return depth_mean def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor: """ Decode normal latent into normal map. Args: normal_latent (`torch.Tensor`): Depth latent to be decoded. Returns: `torch.Tensor`: Decoded normal map. """ # scale latent normal_latent = normal_latent / self.latent_scale_factor # decode z = self.vae.post_quant_conv(normal_latent) normal = self.vae.decoder(z) return normal