import math
import numpy as np
from omegaconf import OmegaConf
from pathlib import Path
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from torchvision.utils import save_image
from torchvision.ops import masks_to_boxes
from torchvision.transforms import Resize
from diffusers import DDIMScheduler, DDPMScheduler
from einops import rearrange, repeat
from tqdm import tqdm

import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
sys.path.append("./models/")

from loguru import logger

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.diffusionmodules.util import extract_into_tensor

# load model
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=True):
    pl_sd = torch.load(ckpt, map_location='cpu')

    if 'global_step' in pl_sd and verbose:
        logger.info(f'Global Step: {pl_sd["global_step"]}')

    sd = pl_sd['state_dict']

    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)

    if len(m) > 0:
        logger.warning(f'missing keys:\n{m}')
    if len(u) > 0:
        logger.warning(f'unexpected keys:\n{u}')

    # manually copy the EMA weights into the model, then delete them to save GPU memory
    if model.use_ema:
        logger.debug('loading EMA...')
        model.model_ema.copy_to(model.model)
        del model.model_ema

    if vram_O:
        # the decoder is not needed in low-VRAM mode
        del model.first_stage_model.decoder

    torch.cuda.empty_cache()

    model.eval().to(device)

    # keep the first-stage (VAE) parameters trainable so gradients can flow through encode/decode
    for param in model.first_stage_model.parameters():
        param.requires_grad = True

    return model
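
# A minimal, hedged sketch of calling load_model_from_config on its own; the config and
# checkpoint paths below are placeholders for illustration, not files shipped with this repo:
#
#   config = OmegaConf.load("configs/material_diffusion.yaml")                      # hypothetical path
#   model = load_model_from_config(config, "ckpts/material_diffusion.ckpt", device="cuda")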

class MateralDiffusion(nn.Module):
    def __init__(self, device, fp16,
                 config=None,
                 ckpt=None, vram_O=False, t_range=[0.02, 0.98], opt=None, use_ddim=True):
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.vram_O = vram_O
        self.t_range = t_range
        self.opt = opt

        self.config = OmegaConf.load(config)
        # TODO: it seems the model cannot be loaded in fp16...
        self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O, verbose=True)

        # timesteps: use the diffusers schedulers for convenience
        self.num_train_timesteps = self.config.model.params.timesteps
        self.use_ddim = use_ddim
        if self.use_ddim:
            self.scheduler = DDIMScheduler(
                self.num_train_timesteps,
                self.config.model.params.linear_start,
                self.config.model.params.linear_end,
                beta_schedule='scaled_linear',
                clip_sample=False,
                set_alpha_to_one=False,
                steps_offset=1,
            )
            logger.info("Using DDIM...")
        else:
            self.scheduler = DDPMScheduler(
                self.num_train_timesteps,
                self.config.model.params.linear_start,
                self.config.model.params.linear_end,
                beta_schedule='scaled_linear',
                clip_sample=False,
            )
            logger.info("Using DDPM...")

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

    def get_input(self, x):
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def center_crop(self, img, mask, return_uv=False, mask_ratio=.8, image_size=256):
        # img  ~ (batch, h, w, 3)
        # mask ~ (batch, h, w, C); only the last channel is used
        margin = np.round((1 - mask_ratio) * image_size).astype(int)
        resizer = Resize([np.round(image_size - margin * 2).astype(int),
                          np.round(image_size - margin * 2).astype(int)])

        batch_size = img.shape[0]
        min_max_uv = masks_to_boxes(mask[..., -1] > 0.5)
        min_uv, max_uv = min_max_uv[..., [1, 0]].long(), (min_max_uv[..., [3, 2]] + 1).long()

        # fill the background with ones
        img = (img + (mask[..., -1:] <= 0.5)).clamp(0, 1)
        img = rearrange(img, 'b h w c -> b c h w')
        ori_size = torch.tensor(img.shape[-2:]).to(min_max_uv.device).reshape(1, 2).expand(img.shape[0], -1)

        cropped_imgs = []
        for batch_idx in range(batch_size):
            img_crop = img[batch_idx][:, min_uv[batch_idx, 0]:max_uv[batch_idx, 0],
                                      min_uv[batch_idx, 1]:max_uv[batch_idx, 1]]
            img_crop = resizer(img_crop)
            img_out = torch.ones(3, image_size, image_size).to(img.device)
            img_out[:, margin:image_size - margin, margin:image_size - margin] = img_crop
            cropped_imgs.append(img_out)
        img_new = torch.stack(cropped_imgs, dim=0)
        img_new = rearrange(img_new, 'b c h w -> b h w c')

        crop_uv = torch.stack([ori_size[:, 0], ori_size[:, 1],
                               min_uv[:, 0], min_uv[:, 1],
                               max_uv[:, 0], max_uv[:, 1],
                               max_uv[:, 1] * 0 + margin], dim=-1).float()
        if return_uv:
            return img_new, crop_uv
        return img_new

    def center_crop_aspect_ratio(self, img, mask, return_uv=False, mask_ratio=.8, image_size=256):
        # img  ~ (batch, h, w, 3)
        # mask ~ (batch, h, w, C); only the last channel is used
        # ensure the image border is masked out, since grid_sample only supports border or
        # zeros padding, but we need ones padding here
        border_mask = torch.zeros_like(mask)
        border_mask[:, 1:-1, 1:-1] = 1
        mask = mask * border_mask

        min_max_uv = masks_to_boxes(mask[..., -1] > 0.5)
        min_uv, max_uv = min_max_uv[..., [1, 0]], min_max_uv[..., [3, 2]]

        # fill the background with ones
        img = (img + (mask[..., -1:] <= 0.5)).clamp(0, 1)
        img = rearrange(img, 'b h w c -> b c h w')
        ori_size = torch.tensor(img.shape[-2:]).to(min_max_uv.device).reshape(1, 2).expand(img.shape[0], -1)

        crop_length = torch.div((max_uv - min_uv), 2, rounding_mode='floor')
        half_size = torch.max(crop_length, dim=-1, keepdim=True)[0]
        center_uv = min_uv + crop_length

        # generate the sampling grid
        target_size = image_size
        grid_x, grid_y = torch.meshgrid(torch.arange(0, target_size, 1, device=min_max_uv.device),
                                        torch.arange(0, target_size, 1, device=min_max_uv.device),
                                        indexing='ij')
        normalized_xy = torch.stack([grid_x / (target_size - 1), grid_y / (target_size - 1)], dim=-1)  # [0, 1]
        normalized_xy = (normalized_xy - 0.5) / mask_ratio + 0.5
        normalized_xy = normalized_xy[None].expand(img.shape[0], -1, -1, -1)

        ori_crop_size = 2 * half_size + 1
        xy_scale = (ori_crop_size - 1) / (ori_size - 1)
        normalized_xy = normalized_xy * xy_scale.reshape(-1, 1, 1, 2)[..., [0, 1]]
        xy_shift = (center_uv - half_size) / (ori_size - 1)
        normalized_xy = normalized_xy + xy_shift.reshape(-1, 1, 1, 2)[..., [0, 1]]
        normalized_xy = normalized_xy * 2 - 1  # [-1, 1]

        img_new = F.grid_sample(img, normalized_xy[..., [1, 0]], padding_mode='border', align_corners=True)
        crop_uv = torch.stack([ori_size[:, 0], ori_size[:, 1],
                               half_size[..., 0] * 0.0 + mask_ratio,
                               half_size[..., 0],
                               center_uv[:, 0], center_uv[:, 1]], dim=-1).float()
        img_new = rearrange(img_new, 'b c h w -> b h w c')
        if return_uv:
            return img_new, crop_uv
        return img_new

    def restore_crop(self, img, img_ori, crop_idx):
        ori_size, min_uv, max_uv, margin = crop_idx[:, :2].long(), crop_idx[:, 2:4].long(), \
            crop_idx[:, 4:6].long(), crop_idx[0, 6].long().item()
        batch_size = img.shape[0]
        all_images = []
        for batch_idx in range(batch_size):
            img_out = torch.ones(3, ori_size[batch_idx][0], ori_size[batch_idx][1]).to(img.device)
            cropped_size = max_uv[batch_idx] - min_uv[batch_idx]
            resizer = Resize([cropped_size[0], cropped_size[1]])
            net_size = img[batch_idx].shape[-1]
            img_crop = resizer(img[batch_idx][:, margin:net_size - margin, margin:net_size - margin])
            img_out[:, min_uv[batch_idx, 0]:max_uv[batch_idx, 0],
                    min_uv[batch_idx, 1]:max_uv[batch_idx, 1]] = img_crop
            all_images.append(img_out)
        all_images = torch.stack(all_images, dim=0)
        all_images = rearrange(all_images, 'b c h w -> b h w c')
        return all_images

    def restore_crop_aspect_ratio(self, img, img_ori, crop_idx):
        ori_size, mask_ratio, half_size, center_uv = crop_idx[:, :2].long(), crop_idx[:, 2:3], \
            crop_idx[:, 3:4].long(), crop_idx[:, 4:].long()
        # force a white one-pixel border so border padding fills everything outside the crop with white
        img[:, :, 0, :] = 1
        img[:, :, -1, :] = 1
        img[:, :, :, 0] = 1
        img[:, :, :, -1] = 1

        ori_crop_size = 2 * half_size + 1
        grid_x, grid_y = torch.meshgrid(torch.arange(0, ori_size[0, 0].item(), 1, device=img.device),
                                        torch.arange(0, ori_size[0, 1].item(), 1, device=img.device),
                                        indexing='ij')
        normalized_xy = torch.stack([grid_x, grid_y], dim=-1)[None].expand(img.shape[0], -1, -1, -1) - \
            (center_uv - half_size).reshape(-1, 1, 1, 2)[..., [0, 1]]
        normalized_xy = normalized_xy / (ori_crop_size - 1).reshape(-1, 1, 1, 1)
        normalized_xy = (2 * normalized_xy - 1) * mask_ratio.reshape(-1, 1, 1, 1)

        img_out = F.grid_sample(img, normalized_xy[..., [1, 0]], padding_mode='border', align_corners=True)
        img_out = rearrange(img_out, 'b c h w -> b h w c')
        return img_out
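
    # Hedged round-trip sketch (not part of the original code): center_crop returns a square
    # crop plus crop_uv = (H, W, min_u, min_v, max_u, max_v, margin), and restore_crop pastes
    # a generated square (channels-first) back into the original canvas using that record:
    #
    #   crop, uv = self.center_crop(img, mask, return_uv=True, mask_ratio=1.0)   # (B, 256, 256, 3)
    #   restored = self.restore_crop(rearrange(crop, 'b h w c -> b c h w'), img, uv)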

    def _image2diffusion(self, embeddings, pred_rgb, mask, image_size=256):
        # pred_rgb: tensor [B, H, W, 3] in [0, 1] (channels-last, as expected by center_crop)
        assert len(pred_rgb.shape) == 4, f"expected a 4-dim tensor, got: {pred_rgb.shape}"
        cond_img = embeddings["cond_img"]
        cond_img = self.center_crop(cond_img, mask, mask_ratio=1.0, image_size=image_size)
        pred_rgb_256, crop_idx_all = self.center_crop(pred_rgb, mask, return_uv=True, mask_ratio=1.0, image_size=image_size)
        mask_img = self.center_crop(1 - mask.expand(-1, -1, -1, 3), mask, mask_ratio=1.0, image_size=image_size)

        xc = self.get_input(cond_img)
        pred_rgb_256 = self.get_input(pred_rgb_256)
        return pred_rgb_256, crop_idx_all, xc

    def _get_condition(self, xc, with_uncondition=False):
        # Build the conditioning for classifier-free guidance: the learned (CLIP) embedding goes
        # to cross-attention and the VAE-encoded condition image is channel-concatenated; when
        # with_uncondition is set, a zeroed unconditional branch is stacked in front so that
        # __call__ can run both branches in a single batch.
        xc = xc * 2 - 1
        cond = {}
        clip_emb = self.model.get_learned_conditioning(xc if self.model.use_clip_embdding else [""]).detach()
        c_concat = self.model.encode_first_stage(xc.to(self.device)).mode().detach()
        if with_uncondition:
            cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
            cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]
        else:
            cond['c_crossattn'] = [clip_emb]
            cond['c_concat'] = [c_concat]
        return cond

    def __call__(self, embeddings, pred_rgb, mask, guidance_scale=3, dps_scale=0.2, as_latent=False,
                 grad_scale=1, save_guidance_path: Path = None, ddim_steps=200, ddim_eta=1, operator=None):
        # TODO: the upscale factor is currently hard-coded
        upscale = 1
        pred_rgb_256, crop_idx_all, xc = self._image2diffusion(embeddings, pred_rgb, mask, image_size=256 * upscale)
        cond = self._get_condition(xc, with_uncondition=True)
        assert pred_rgb_256.shape[-1] == pred_rgb_256.shape[-2], f"Expected a square image, got {pred_rgb_256.shape}"

        # start from pure noise in latent space
        latents = torch.randn_like(self.encode_imgs(pred_rgb_256))
        if self.use_ddim:
            self.scheduler.set_timesteps(ddim_steps)
        else:
            self.scheduler.set_timesteps(self.num_train_timesteps)

        intermediates = []
        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
            # classifier-free guidance: run the unconditional and conditional branches in one batch
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1).expand(latents.shape[0])] * 2).to(self.device)
            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # diffusion posterior sampling (DPS): steer the denoising with the gradient of the
            # measurement error between the rendered target and the predicted clean image
            if dps_scale > 0:
                with torch.enable_grad():
                    t_batch = torch.zeros(latents.shape[0], dtype=torch.long, device=self.device) + t
                    x_hat_latents = self.model.predict_start_from_noise(latents.requires_grad_(True), t_batch, noise_pred)
                    x_hat = self.decode_latents(x_hat_latents)
                    x_hat = operator.forward(x_hat)
                    norm = torch.linalg.norm((pred_rgb_256 - x_hat).reshape(pred_rgb_256.shape[0], -1), dim=-1)
                    guidance_score = torch.autograd.grad(norm.sum(), latents, retain_graph=True)[0]
                    if (save_guidance_path is not None) and i % (len(self.scheduler.timesteps) // 20) == 0:
                        x_t = self.decode_latents(latents)
                        intermediates.append(torch.cat([x_hat, x_t, pred_rgb_256, pred_rgb_256 - x_hat], dim=-2).detach().cpu())
                    logger.debug(f"Guidance loss: {norm}")
                noise_pred = noise_pred + dps_scale * guidance_score

            if self.use_ddim:
                latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']
            else:
                latents = self.scheduler.step(noise_pred.clone().detach(), t, latents)['prev_sample']

            if dps_scale > 0:
                del x_hat
                del guidance_score
                del noise_pred
                del x_hat_latents
                del norm

        imgs = self.decode_latents(latents)

        viz_images = torch.cat([pred_rgb_256, imgs], dim=-1)[:1]
        if save_guidance_path is not None and len(intermediates) > 0:
            save_image(viz_images, save_guidance_path)
            viz_images = torch.cat(intermediates, dim=-1)[:1]
            save_image(viz_images, save_guidance_path + "all.jpg")

        # transform back to the original image size
        img_ori_size = self.restore_crop(imgs, pred_rgb, crop_idx_all)
        if save_guidance_path is not None:
            img_ori_size_save = rearrange(img_ori_size, 'b h w c -> b c h w')[:1]
            save_image(img_ori_size_save, save_guidance_path + "_out.jpg")
        return img_ori_size

    def decode_latents(self, latents):
        # latents: [B, 4, 32, 32] latent-space image
        imgs = self.model.decode_first_stage(latents)
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs  # [B, 3, 256, 256] RGB image in [0, 1]

    def encode_imgs(self, imgs):
        # imgs: [B, 3, 256, 256] RGB image in [0, 1]
        imgs = imgs * 2 - 1
        latents = self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs))
        return latents  # [B, 4, 32, 32] latent-space image
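
# --------------------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of driving the
# guidance end to end. The config/checkpoint paths and the identity DPS operator below are
# assumptions for illustration only; replace them with the project's actual assets and
# measurement operator.
if __name__ == "__main__":
    class IdentityOperator:
        """Trivial DPS measurement operator: the measurement is the image itself."""
        def forward(self, x):
            return x

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    guidance = MateralDiffusion(
        device, fp16=False,
        config="configs/material_diffusion.yaml",   # hypothetical path
        ckpt="ckpts/material_diffusion.ckpt",       # hypothetical path
    )

    # channels-last inputs, as expected by center_crop: RGB in [0, 1] plus a foreground mask
    pred_rgb = torch.rand(1, 256, 256, 3, device=device)
    mask = torch.ones(1, 256, 256, 1, device=device)
    embeddings = {"cond_img": pred_rgb.clone()}

    out = guidance(embeddings, pred_rgb, mask,
                   guidance_scale=3, dps_scale=0.2,
                   ddim_steps=50, operator=IdentityOperator())
    print(out.shape)  # (1, 256, 256, 3), same spatial size as the input render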