GenPercept / pipeline_genpercept.py
guangkaixu's picture
upload
e487722
raw
history blame
No virus
12.8 kB
# --------------------------------------------------------
# Diffusion Models Trained with Large Data Are Transferable Visual Models (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024 Zhejiang University
# Licensed under The CC0 1.0 License [see LICENSE for details]
# By Guangkai Xu
# Based on Marigold, diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
from typing import List, Dict, Union
from torch.utils.data import DataLoader, TensorDataset
from diffusers import (
DiffusionPipeline,
UNet2DConditionModel,
AutoencoderKL,
)
from diffusers.utils import BaseOutput
from util.image_util import chw2hwc, colorize_depth_maps, resize_max_res, norm_to_rgb, resize_res
from util.batchsize import find_batch_size
class GenPerceptOutput(BaseOutput):
pred_np: np.ndarray
pred_colored: Image.Image
class GenPerceptPipeline(DiffusionPipeline):
vae_scale_factor = 0.18215
task_infos = {
'depth': dict(task_channel_num=1, interpolate='bilinear', ),
'seg': dict(task_channel_num=3, interpolate='nearest', ),
'sr': dict(task_channel_num=3, interpolate='nearest', ),
'normal': dict(task_channel_num=3, interpolate='bilinear', ),
}
def __init__(
self,
unet: UNet2DConditionModel,
vae: AutoencoderKL,
customized_head=None,
empty_text_embed=None,
):
super().__init__()
self.empty_text_embed = empty_text_embed
# register
register_dict = dict(
unet=unet,
vae=vae,
customized_head=customized_head,
)
self.register_modules(**register_dict)
@torch.no_grad()
def __call__(
self,
input_image: Union[Image.Image, torch.Tensor],
mode: str = 'depth',
resize_hard = False,
processing_res: int = 768,
match_input_res: bool = True,
batch_size: int = 0,
color_map: str = "Spectral",
show_progress_bar: bool = True,
) -> GenPerceptOutput:
"""
Function invoked when calling the pipeline.
Args:
input_image (Image):
Input RGB (or gray-scale) image.
processing_res (int, optional):
Maximum resolution of processing.
If set to 0: will not resize at all.
Defaults to 768.
match_input_res (bool, optional):
Resize depth prediction to match input resolution.
Only valid if `limit_input_res` is not None.
Defaults to True.
batch_size (int, optional):
Inference batch size.
If set to 0, the script will automatically decide the proper batch size.
Defaults to 0.
show_progress_bar (bool, optional):
Display a progress bar of diffusion denoising.
Defaults to True.
color_map (str, optional):
Colormap used to colorize the depth map.
Defaults to "Spectral".
Returns:
`GenPerceptOutput`
"""
device = self.device
task_channel_num = self.task_infos[mode]['task_channel_num']
if not match_input_res:
assert (
processing_res is not None
), "Value error: `resize_output_back` is only valid with "
assert processing_res >= 0
# ----------------- Image Preprocess -----------------
if type(input_image) == torch.Tensor: # [B, 3, H, W]
rgb_norm = input_image.to(device)
input_size = input_image.shape[2:]
bs_imgs = rgb_norm.shape[0]
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
rgb_norm = rgb_norm.to(self.dtype)
else:
# if len(rgb_paths) > 0 and 'kitti' in rgb_paths[0]:
# # kb crop
# height = input_image.size[1]
# width = input_image.size[0]
# top_margin = int(height - 352)
# left_margin = int((width - 1216) / 2)
# input_image = input_image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
# TODO: check the kitti evaluation resolution here.
input_size = (input_image.size[1], input_image.size[0])
# Resize image
if processing_res > 0:
if resize_hard:
input_image = resize_res(
input_image, max_edge_resolution=processing_res
)
else:
input_image = resize_max_res(
input_image, max_edge_resolution=processing_res
)
input_image = input_image.convert("RGB")
image = np.asarray(input_image)
# Normalize rgb values
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
rgb_norm = rgb / 255.0 * 2.0 - 1.0
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
rgb_norm = rgb_norm[None].to(device)
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
bs_imgs = 1
# ----------------- Predicting depth -----------------
single_rgb_dataset = TensorDataset(rgb_norm)
if batch_size > 0:
_bs = batch_size
else:
_bs = find_batch_size(
ensemble_size=1,
input_res=max(rgb_norm.shape[1:]),
dtype=self.dtype,
)
single_rgb_loader = DataLoader(
single_rgb_dataset, batch_size=_bs, shuffle=False
)
# Predict depth maps (batched)
pred_list = []
if show_progress_bar:
iterable = tqdm(
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
)
else:
iterable = single_rgb_loader
for batch in iterable:
(batched_img, ) = batch
pred = self.single_infer(
rgb_in=batched_img,
mode=mode,
)
pred_list.append(pred.detach().clone())
preds = torch.concat(pred_list, axis=0).squeeze() # [bs_imgs, task_channel_num, H, W]
preds = preds.view(bs_imgs, task_channel_num, preds.shape[-2], preds.shape[-1])
if match_input_res:
preds = F.interpolate(preds, input_size, mode=self.task_infos[mode]['interpolate'])
# ----------------- Post processing -----------------
if mode == 'depth':
if len(preds.shape) == 4:
preds = preds[:, 0] # [bs_imgs, H, W]
# Scale prediction to [0, 1]
min_d = preds.view(bs_imgs, -1).min(dim=1)[0]
max_d = preds.view(bs_imgs, -1).max(dim=1)[0]
preds = (preds - min_d[:, None, None]) / (max_d[:, None, None] - min_d[:, None, None])
preds = preds.cpu().numpy().astype(np.float32)
# Colorize
pred_colored_img_list = []
for i in range(bs_imgs):
pred_colored_chw = colorize_depth_maps(
preds[i], 0, 1, cmap=color_map
).squeeze() # [3, H, W], value in (0, 1)
pred_colored_chw = (pred_colored_chw * 255).astype(np.uint8)
pred_colored_hwc = chw2hwc(pred_colored_chw)
pred_colored_img = Image.fromarray(pred_colored_hwc)
pred_colored_img_list.append(pred_colored_img)
return GenPerceptOutput(
pred_np=np.squeeze(preds),
pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
)
elif mode == 'seg' or mode == 'sr':
if not self.customized_head:
# shift to [0, 1]
preds = (preds + 1.0) / 2.0
# shift to [0, 255]
preds = preds * 255
# Clip output range
preds = preds.clip(0, 255).cpu().numpy().astype(np.uint8)
else:
raise NotImplementedError
pred_colored_img_list = []
for i in range(preds.shape[0]):
pred_colored_hwc = chw2hwc(preds[i])
pred_colored_img = Image.fromarray(pred_colored_hwc)
pred_colored_img_list.append(pred_colored_img)
return GenPerceptOutput(
pred_np=np.squeeze(preds),
pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
)
elif mode == 'normal':
if not self.customized_head:
preds = preds.clip(-1, 1).cpu().numpy() # [-1, 1]
else:
raise NotImplementedError
pred_colored_img_list = []
for i in range(preds.shape[0]):
pred_colored_chw = norm_to_rgb(preds[i])
pred_colored_hwc = chw2hwc(pred_colored_chw)
normal_colored_img_i = Image.fromarray(pred_colored_hwc)
pred_colored_img_list.append(normal_colored_img_i)
return GenPerceptOutput(
pred_np=np.squeeze(preds),
pred_colored=pred_colored_img_list[0] if len(pred_colored_img_list) == 1 else pred_colored_img_list,
)
else:
raise NotImplementedError
@torch.no_grad()
def single_infer(
self,
rgb_in: torch.Tensor,
mode: str = 'depth',
) -> torch.Tensor:
"""
Perform an individual depth prediction without ensembling.
Args:
rgb_in (torch.Tensor):
Input RGB image.
num_inference_steps (int):
Number of diffusion denoising steps (DDIM) during inference.
show_pbar (bool):
Display a progress bar of diffusion denoising.
Returns:
torch.Tensor: Predicted depth map.
"""
device = rgb_in.device
bs_imgs = rgb_in.shape[0]
timesteps = torch.tensor([1]).long().repeat(bs_imgs).to(device)
# Encode image
rgb_latent = self.encode_rgb(rgb_in)
batch_embed = self.empty_text_embed
batch_embed = batch_embed.repeat((rgb_latent.shape[0], 1, 1)).to(device) # [bs_imgs, 77, 1024]
# Forward!
if self.customized_head:
unet_features = self.unet(rgb_latent, timesteps, encoder_hidden_states=batch_embed, return_feature_only=True)[0][::-1]
pred = self.customized_head(unet_features)
else:
unet_output = self.unet(
rgb_latent, timesteps, encoder_hidden_states=batch_embed
) # [bs_imgs, 4, h, w]
unet_pred = unet_output.sample
pred_latent = - unet_pred
pred_latent.to(device)
pred = self.decode_pred(pred_latent)
if mode == 'depth':
# mean of output channels
pred = pred.mean(dim=1, keepdim=True)
# clip prediction
pred = torch.clip(pred, -1.0, 1.0)
return pred
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
"""
Encode RGB image into latent.
Args:
rgb_in (torch.Tensor):
Input RGB image to be encoded.
Returns:
torch.Tensor: Image latent
"""
try:
# encode
h_temp = self.vae.encoder(rgb_in)
moments = self.vae.quant_conv(h_temp)
except:
# encode
h_temp = self.vae.encoder(rgb_in.float())
moments = self.vae.quant_conv(h_temp.float())
mean, logvar = torch.chunk(moments, 2, dim=1)
# scale latent
rgb_latent = mean * self.vae_scale_factor
return rgb_latent
def decode_pred(self, pred_latent: torch.Tensor) -> torch.Tensor:
"""
Decode pred latent into pred label.
Args:
pred_latent (torch.Tensor):
prediction latent to be decoded.
Returns:
torch.Tensor: Decoded prediction label.
"""
# scale latent
pred_latent = pred_latent / self.vae_scale_factor
# decode
z = self.vae.post_quant_conv(pred_latent)
pred = self.vae.decoder(z)
return pred