import math
import numpy as np
from omegaconf import OmegaConf
from pathlib import Path
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from torchvision.utils import save_image
from torchvision.ops import masks_to_boxes
from torchvision.transforms import Resize
from diffusers import DDIMScheduler, DDPMScheduler
from einops import rearrange, repeat
from tqdm import tqdm
import sys
from os import path

# These path appends must run before the `ldm` imports below so that the
# package can be resolved from the repository root or ./models/.
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
sys.path.append("./models/")

from loguru import logger
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.diffusionmodules.util import extract_into_tensor


# load model
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=True):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Checkpoint path; the loaded object must contain ``'state_dict'``
            and may contain ``'global_step'``.
        device: Device the model is moved to.
        vram_O: If True, delete ``model.first_stage_model.decoder``.
            NOTE(review): decoding (``decode_first_stage``) presumably needs the
            decoder — confirm before decoding with a model loaded this way.
        verbose: If True, log the checkpoint's global step when present.

    Returns:
        The model in eval mode on ``device``. EMA weights are copied in when
        ``model.use_ema`` is set, and first-stage parameters have
        ``requires_grad`` enabled.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location='cpu')
    if 'global_step' in pl_sd and verbose:
        logger.info(f'Global Step: {pl_sd["global_step"]}')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    # Loguru (unlike print) ignores extra positional args when the message has
    # no "{}" placeholder, so embed the key lists directly in the message.
    if len(missing) > 0:
        logger.warning(f'missing keys: \n{missing}')
    if len(unexpected) > 0:
        logger.warning(f'unexpected keys: \n{unexpected}')
    # Copy EMA weights into the model, then drop the EMA copy before the model
    # is moved to `device`, so it never occupies device memory.
    if model.use_ema:
        logger.debug('loading EMA...')
        model.model_ema.copy_to(model.model)
        del model.model_ema
    if vram_O:
        # we don't need decoder
        del model.first_stage_model.decoder
    torch.cuda.empty_cache()
    model.eval().to(device)
    # First-stage parameters are left trainable. NOTE(review): presumably
    # related to back-propagating through the decoder during guidance — confirm.
    for param in model.first_stage_model.parameters():
        param.requires_grad = True
    return model


class MateralDiffusion(nn.Module):
    """Conditional latent-diffusion sampler with optional measurement guidance.

    Wraps a model loaded by ``load_model_from_config`` together with a diffusers
    DDIM or DDPM scheduler. Calling an instance does three things:

    1. crops the inputs around their mask;
    2. samples a new image conditioned on ``embeddings["cond_img"]``;
    3. pastes the sample back into the original frame.
    """

    def __init__(self, device, fp16, config=None, ckpt=None, vram_O=False,
                 t_range=(0.02, 0.98), opt=None, use_ddim=True):
        """Load the model and build the noise scheduler.

        Args:
            device: Device for the model and scheduler tensors.
            fp16: Stored as ``self.fp16``; not read by the code in this file.
            config: File passed to ``OmegaConf.load``. Must define
                ``model.params.timesteps``, ``model.params.linear_start`` and
                ``model.params.linear_end``.
            ckpt: Checkpoint path forwarded to ``load_model_from_config``.
            vram_O: Forwarded to ``load_model_from_config``.
            t_range: ``(low, high)`` fractions of the training timesteps used
                to derive ``min_step`` / ``max_step``.
            opt: Arbitrary options object, stored as ``self.opt``.
            use_ddim: Use ``DDIMScheduler`` if True, else ``DDPMScheduler``.
        """
        super().__init__()
        self.device = device
        self.fp16 = fp16
        self.vram_O = vram_O
        # Copy into a fresh list so instances never share a mutable default.
        self.t_range = list(t_range)
        self.opt = opt
        self.config = OmegaConf.load(config)
        # TODO: seems it cannot load into fp16...
        self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O, verbose=True)
        # Timesteps: reuse diffusers schedulers configured from the model's
        # own beta schedule.
        self.num_train_timesteps = self.config.model.params.timesteps
        self.use_ddim = use_ddim
        if self.use_ddim:
            self.scheduler = DDIMScheduler(
                self.num_train_timesteps,
                self.config.model.params.linear_start,
                self.config.model.params.linear_end,
                beta_schedule='scaled_linear',
                clip_sample=False,
                set_alpha_to_one=False,
                steps_offset=1,
            )
            logger.info("Using DDIM...")
        else:
            self.scheduler = DDPMScheduler(
                self.num_train_timesteps,
                self.config.model.params.linear_start,
                self.config.model.params.linear_end,
                beta_schedule='scaled_linear',
                clip_sample=False,
            )
            logger.info("Using DDPM...")
        self.min_step = int(self.num_train_timesteps * self.t_range[0])
        self.max_step = int(self.num_train_timesteps * self.t_range[1])
        # Cumulative alpha products on the target device, kept for convenience.
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)

    def get_input(self, x):
        """Convert a channels-last batch to a contiguous float32 channels-first tensor.

        A rank-3 input ``(b, h, w)`` first gets a trailing singleton channel.
        """
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        return x.to(memory_format=torch.contiguous_format).float()

    def center_crop(self, img, mask, return_uv=False, mask_ratio=.8, image_size=256):
        """Crop each image to its mask bounding box and resize it into a square canvas.

        Background pixels (mask <= 0.5) are filled with ones before cropping.

        Args:
            img: Channels-last batch ``(b, h, w, 3)``.
            mask: Channels-last mask batch; its last channel > 0.5 marks the
                foreground.
            return_uv: Also return the crop parameters used by ``restore_crop``.
            mask_ratio: Sets a margin of ``round((1 - mask_ratio) * image_size)``
                pixels (filled with ones) on each side of the resized crop.
            image_size: Side length of the square output.

        Returns:
            The ``(b, image_size, image_size, 3)`` channels-last batch. If
            ``return_uv`` is set, also a float ``(b, 7)`` tensor laid out as
            ``[H, W, min_row, min_col, max_row, max_col, margin]``, where the
            max corner is exclusive.
        """
        margin = int(np.round((1 - mask_ratio) * image_size))
        inner_size = int(image_size - 2 * margin)
        resizer = Resize([inner_size, inner_size])
        batch_size = img.shape[0]
        # NOTE(review): masks_to_boxes presumably fails on an all-background
        # mask; callers are expected to supply a non-empty foreground.
        min_max_uv = masks_to_boxes(mask[..., -1] > 0.5)  # (b, 4) as x1, y1, x2, y2
        # Reorder to (row, col); the +1 makes the max corner exclusive for slicing.
        min_uv = min_max_uv[..., [1, 0]].long()
        max_uv = (min_max_uv[..., [3, 2]] + 1).long()
        # Fill the background with ones.
        img = (img + (mask[..., -1:] <= 0.5)).clamp(0, 1)
        img = rearrange(img, 'b h w c -> b c h w')
        ori_size = torch.tensor(img.shape[-2:]).to(min_max_uv.device).reshape(1, 2).expand(batch_size, -1)
        cropped_imgs = []
        for batch_idx in range(batch_size):
            img_crop = img[batch_idx][:, min_uv[batch_idx, 0]:max_uv[batch_idx, 0],
                                      min_uv[batch_idx, 1]:max_uv[batch_idx, 1]]
            img_crop = resizer(img_crop)
            img_out = torch.ones(3, image_size, image_size, device=img.device)
            img_out[:, margin:image_size - margin, margin:image_size - margin] = img_crop
            cropped_imgs.append(img_out)
        img_new = rearrange(torch.stack(cropped_imgs, dim=0), 'b c h w -> b h w c')
        if not return_uv:
            return img_new
        crop_uv = torch.stack([
            ori_size[:, 0], ori_size[:, 1],
            min_uv[:, 0], min_uv[:, 1],
            max_uv[:, 0], max_uv[:, 1],
            torch.full_like(max_uv[:, 1], margin),
        ], dim=-1).float()
        return img_new, crop_uv

    def center_crop_aspect_ratio(self, img, mask, return_uv=False, mask_ratio=.8, image_size=256):
        """Resample a square window around the mask into an ``image_size`` canvas.

        The window is centred on the mask bounding box. Its half-side is the
        larger half-extent of that box, so the aspect ratio is preserved.
        Sampling coordinates are additionally scaled by ``1 / mask_ratio``
        around the canvas centre.

        Args:
            img: Channels-last batch ``(b, h, w, c)``.
            mask: Channels-last mask batch; its last channel > 0.5 marks the
                foreground.
            return_uv: Also return parameters for ``restore_crop_aspect_ratio``.
            mask_ratio: Scale of the content inside the output canvas.
            image_size: Side length of the square output.

        Returns:
            The ``(b, image_size, image_size, c)`` channels-last batch. If
            ``return_uv`` is set, also a float ``(b, 6)`` tensor laid out as
            ``[H, W, mask_ratio, half_size, center_row, center_col]``.
        """
        # Zero the outermost mask pixels so they become background (ones).
        # grid_sample only supports 'border'/'zeros' padding, and with 'border'
        # this makes out-of-range samples replicate ones.
        border_mask = torch.zeros_like(mask)
        border_mask[:, 1:-1, 1:-1] = 1
        mask = mask * border_mask
        min_max_uv = masks_to_boxes(mask[..., -1] > 0.5)
        min_uv, max_uv = min_max_uv[..., [1, 0]], min_max_uv[..., [3, 2]]
        # Fill the background with ones.
        img = (img + (mask[..., -1:] <= 0.5)).clamp(0, 1)
        img = rearrange(img, 'b h w c -> b c h w')
        ori_size = torch.tensor(img.shape[-2:]).to(min_max_uv.device).reshape(1, 2).expand(img.shape[0], -1)
        crop_length = torch.div((max_uv - min_uv), 2, rounding_mode='floor')
        half_size = torch.max(crop_length, dim=-1, keepdim=True)[0]
        center_uv = min_uv + crop_length

        # Output grid in [0, 1], expanded around the centre by 1 / mask_ratio.
        target_size = image_size
        grid_x, grid_y = torch.meshgrid(
            torch.arange(0, target_size, 1, device=min_max_uv.device),
            torch.arange(0, target_size, 1, device=min_max_uv.device),
            indexing='ij')
        normalized_xy = torch.stack([grid_x / (target_size - 1), grid_y / (target_size - 1)], dim=-1)
        normalized_xy = (normalized_xy - 0.5) / mask_ratio + 0.5
        normalized_xy = normalized_xy[None].expand(img.shape[0], -1, -1, -1)
        # Map the unit grid onto the (2 * half_size + 1)-pixel window in the
        # source image, then to grid_sample's [-1, 1] range.
        ori_crop_size = 2 * half_size + 1
        xy_scale = (ori_crop_size - 1) / (ori_size - 1)
        normalized_xy = normalized_xy * xy_scale.reshape(-1, 1, 1, 2)
        xy_shift = (center_uv - half_size) / (ori_size - 1)
        normalized_xy = normalized_xy + xy_shift.reshape(-1, 1, 1, 2)
        normalized_xy = normalized_xy * 2 - 1
        # grid_sample expects (x=col, y=row) order, hence the swap.
        img_new = F.grid_sample(img, normalized_xy[..., [1, 0]], padding_mode='border', align_corners=True)
        img_new = rearrange(img_new, 'b c h w -> b h w c')
        if not return_uv:
            return img_new
        crop_uv = torch.stack([
            ori_size[:, 0], ori_size[:, 1],
            half_size[..., 0] * 0.0 + mask_ratio,
            half_size[..., 0],
            center_uv[:, 0], center_uv[:, 1],
        ], dim=-1).float()
        return img_new, crop_uv

    def restore_crop(self, img, img_ori, crop_idx):
        """Undo ``center_crop``: resize each sample back into its original box.

        Args:
            img: Channels-first square batch ``(b, 3, s, s)``.
            img_ori: Unused; kept for interface compatibility.
            crop_idx: ``(b, 7)`` crop parameters produced by ``center_crop``.
                The margin is read from row 0 only; ``center_crop`` writes the
                same margin for every row.

        Returns:
            A channels-last ``(b, H, W, 3)`` batch that is ones outside the
            crop box.
        """
        ori_size = crop_idx[:, :2].long()
        min_uv = crop_idx[:, 2:4].long()
        max_uv = crop_idx[:, 4:6].long()
        margin = crop_idx[0, 6].long().item()
        net_size = img.shape[-1]
        restored = []
        for batch_idx in range(img.shape[0]):
            # Plain ints keep Resize / torch.ones independent of 0-d tensors.
            height, width = ori_size[batch_idx].tolist()
            row0, col0 = min_uv[batch_idx].tolist()
            row1, col1 = max_uv[batch_idx].tolist()
            img_out = torch.ones(3, height, width, device=img.device)
            resizer = Resize([row1 - row0, col1 - col0])
            img_crop = resizer(img[batch_idx][:, margin:net_size - margin, margin:net_size - margin])
            img_out[:, row0:row1, col0:col1] = img_crop
            restored.append(img_out)
        return rearrange(torch.stack(restored, dim=0), 'b c h w -> b h w c')

    def restore_crop_aspect_ratio(self, img, img_ori, crop_idx):
        """Undo ``center_crop_aspect_ratio`` by resampling back into the original frame.

        Args:
            img: Channels-first batch ``(b, c, s, s)``. It is not modified.
            img_ori: Unused; kept for interface compatibility.
            crop_idx: ``(b, 6)`` parameters from ``center_crop_aspect_ratio``.
                The output size is taken from row 0.

        Returns:
            A channels-last batch at the original resolution.
        """
        ori_size = crop_idx[:, :2].long()
        mask_ratio = crop_idx[:, 2:3]
        half_size = crop_idx[:, 3:4].long()
        center_uv = crop_idx[:, 4:].long()
        # Clone so setting the one-pixel border to 1 does not mutate the
        # caller's tensor. With 'border' padding, out-of-window samples then
        # replicate ones.
        img = img.clone()
        img[:, :, 0, :] = 1
        img[:, :, -1, :] = 1
        img[:, :, :, 0] = 1
        img[:, :, :, -1] = 1
        ori_crop_size = 2 * half_size + 1
        grid_x, grid_y = torch.meshgrid(
            torch.arange(0, ori_size[0, 0].item(), 1, device=img.device),
            torch.arange(0, ori_size[0, 1].item(), 1, device=img.device),
            indexing='ij')
        crop_origin = (center_uv - half_size).reshape(-1, 1, 1, 2)
        normalized_xy = torch.stack([grid_x, grid_y], dim=-1)[None].expand(img.shape[0], -1, -1, -1) - crop_origin
        normalized_xy = normalized_xy / (ori_crop_size - 1).reshape(-1, 1, 1, 1)
        normalized_xy = (2 * normalized_xy - 1) * mask_ratio.reshape(-1, 1, 1, 1)
        # grid_sample expects (x=col, y=row) order, hence the swap.
        img_out = F.grid_sample(img, normalized_xy[..., [1, 0]], padding_mode='border', align_corners=True)
        return rearrange(img_out, 'b c h w -> b h w c')

    def _image2diffusion(self, embeddings, pred_rgb, mask, image_size=256):
        """Crop the condition and target images around ``mask`` for the model.

        Args:
            embeddings: Dict providing ``"cond_img"``, a channels-last batch.
            pred_rgb: Channels-last 4-D batch ``(b, h, w, 3)``.
            mask: Channels-last mask batch, foreground in its last channel.
            image_size: Side length of the square crops.

        Returns:
            A tuple ``(pred_rgb_256, crop_idx_all, xc)``:
            ``pred_rgb_256`` is the cropped target (channels-first float);
            ``crop_idx_all`` holds the crop parameters for ``restore_crop``;
            ``xc`` is the cropped condition image (channels-first float).
        """
        assert len(pred_rgb.shape) == 4, f"except 4 dim tensor, got: {pred_rgb.shape}"
        cond_img = self.center_crop(embeddings["cond_img"], mask, mask_ratio=1.0, image_size=image_size)
        pred_rgb_256, crop_idx_all = self.center_crop(pred_rgb, mask, return_uv=True, mask_ratio=1.0, image_size=image_size)
        xc = self.get_input(cond_img)
        pred_rgb_256 = self.get_input(pred_rgb_256)
        return pred_rgb_256, crop_idx_all, xc

    def _get_condition(self, xc, with_uncondition=False):
        """Build the conditioning dict passed to ``self.model.apply_model``.

        Args:
            xc: Channels-first condition batch. It is mapped with
                ``xc * 2 - 1`` (presumably from [0, 1] to [-1, 1]).
            with_uncondition: If True, prepend an all-zeros copy of each entry
                along dim 0. This unconditional half is used for
                classifier-free guidance.

        Returns:
            A dict with two list-wrapped entries:
            ``'c_crossattn'`` holds ``get_learned_conditioning`` of ``xc``
            (or of ``[""]`` when ``self.model.use_clip_embdding`` is false);
            ``'c_concat'`` holds the mode of the first-stage encoding of ``xc``.
        """
        xc = xc * 2 - 1
        clip_emb = self.model.get_learned_conditioning(xc if self.model.use_clip_embdding else [""]).detach()
        c_concat = self.model.encode_first_stage(xc.to(self.device)).mode().detach()
        if with_uncondition:
            clip_emb = torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)
            c_concat = torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)
        return {'c_crossattn': [clip_emb], 'c_concat': [c_concat]}

    @torch.no_grad()
    def __call__(self, embeddings, pred_rgb, mask, guidance_scale=3, dps_scale=0.2, as_latent=False, grad_scale=1,
                 save_guidance_path: "Path | str | None" = None, ddim_steps=200, ddim_eta=1, operator=None):
        """Sample an image conditioned on ``embeddings["cond_img"]`` and paste it back.

        Each step applies classifier-free guidance. When ``dps_scale > 0``, a
        guidance term (marked "dps" in the original code) is also added: the
        gradient, w.r.t. the current latents, of ``||pred_rgb_256 -
        operator(decode(x0_hat))||``.

        Args:
            embeddings: Dict providing ``"cond_img"`` (channels-last batch).
            pred_rgb: Channels-last 4-D target batch; also the measurement the
                guidance residual is computed against.
            mask: Channels-last mask batch, foreground in its last channel.
            guidance_scale: Classifier-free guidance weight.
            dps_scale: Weight of the residual gradient; ``<= 0`` disables it.
            as_latent: Unused.
            grad_scale: Unused.
            save_guidance_path: Optional path / filename prefix (``str`` or
                ``Path``) for debug images.
            ddim_steps: Number of inference steps when using DDIM.
            ddim_eta: ``eta`` forwarded to ``DDIMScheduler.step``.
            operator: Object whose ``forward`` maps a decoded image batch into
                the measurement space. Required when ``dps_scale > 0``.

        Returns:
            The ``restore_crop`` output: a channels-last batch at the original
            resolution.

        Raises:
            ValueError: If ``dps_scale > 0`` and ``operator`` is None.
        """
        if dps_scale > 0 and operator is None:
            raise ValueError("operator must be provided when dps_scale > 0")
        # Normalize to str so the suffix concatenation below works for Path too.
        save_prefix = None if save_guidance_path is None else str(save_guidance_path)
        # TODO: the upscale factor is currently hard-coded.
        upscale = 1
        pred_rgb_256, crop_idx_all, xc = self._image2diffusion(embeddings, pred_rgb, mask, image_size=256 * upscale)
        cond = self._get_condition(xc, with_uncondition=True)
        assert pred_rgb_256.shape[-1] == pred_rgb_256.shape[-2], f"Expect image of square size, get {pred_rgb.shape}"
        latents = torch.randn_like(self.encode_imgs(pred_rgb_256))

        if self.use_ddim:
            self.scheduler.set_timesteps(ddim_steps)
        else:
            self.scheduler.set_timesteps(self.num_train_timesteps)
        timesteps = self.scheduler.timesteps
        # Keep roughly 20 debug snapshots. max() avoids a modulo-by-zero when
        # there are fewer than 20 steps.
        snapshot_every = max(1, len(timesteps) // 20)
        intermediates = []

        for i, t in enumerate(tqdm(timesteps)):
            # Unconditional and conditional halves in one forward pass.
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1).expand(latents.shape[0])] * 2).to(self.device)
            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            if dps_scale > 0:
                with torch.enable_grad():
                    # `* 0 + t` fills the batch with t. The randint draw is
                    # discarded but still consumes random numbers; it is left
                    # in place so seeded runs keep producing the same samples.
                    t_batch = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],),
                                            dtype=torch.long, device=self.device) * 0 + t
                    x_hat_latents = self.model.predict_start_from_noise(latents.requires_grad_(True), t_batch, noise_pred)
                    x_hat = operator.forward(self.decode_latents(x_hat_latents))
                    norm = torch.linalg.norm((pred_rgb_256 - x_hat).reshape(pred_rgb_256.shape[0], -1), dim=-1)
                    guidance_score = torch.autograd.grad(norm.sum(), latents, retain_graph=True)[0]
                    if save_prefix is not None and i % snapshot_every == 0:
                        x_t = self.decode_latents(latents)
                        intermediates.append(
                            torch.cat([x_hat, x_t, pred_rgb_256, pred_rgb_256 - x_hat], dim=-2).detach().cpu())
                    logger.debug(f"Guidance loss: {norm}")
                    noise_pred = noise_pred + dps_scale * guidance_score

            if self.use_ddim:
                latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']
            else:
                latents = self.scheduler.step(noise_pred.clone().detach(), t, latents)['prev_sample']

            if dps_scale > 0:
                # Release this step's graph-holding tensors before the next step.
                del x_hat, guidance_score, noise_pred, x_hat_latents, norm

        imgs = self.decode_latents(latents)
        viz_images = torch.cat([pred_rgb_256, imgs], dim=-1)[:1]
        if save_prefix is not None and len(intermediates) > 0:
            save_image(viz_images, save_prefix)
            save_image(torch.cat(intermediates, dim=-1)[:1], save_prefix + "all.jpg")
        # Transform back to the original image frame.
        img_ori_size = self.restore_crop(imgs, pred_rgb, crop_idx_all)
        if save_prefix is not None:
            img_ori_size_save = rearrange(img_ori_size, 'b h w c -> b c h w')[:1]
            save_image(img_ori_size_save, save_prefix + "_out.jpg")
        return img_ori_size

    def decode_latents(self, latents):
        """Decode latents with the first-stage model and map the output to [0, 1].

        The original comments give latents as ``[B, 4, 32, 32]`` and outputs as
        ``[B, 3, 256, 256]``; these sizes are not enforced here.
        """
        imgs = self.model.decode_first_stage(latents)
        return (imgs / 2 + 0.5).clamp(0, 1)

    def encode_imgs(self, imgs):
        """Map images from [0, 1] to [-1, 1] and encode them to first-stage latents.

        The original comments give inputs as ``[B, 3, 256, 256]`` and latents as
        ``[B, 4, 32, 32]``; these sizes are not enforced here.
        """
        imgs = imgs * 2 - 1
        return self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs))