import inspect
# NOTE(review): this tkinter import is immediately shadowed by
# ``from PIL import Image`` below and looks accidental; it is kept only so the
# module's import set is unchanged.
from tkinter import Image
from typing import Iterator, List, Optional, Union

import numpy as np
import torch

import PIL
from PIL import Image
from tqdm.auto import tqdm
from diffusion_arch import DensePosteriorConditionalUNet
from guided_diffusion.script_util import create_gaussian_diffusion
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
from kornia.morphology import dilation
from tqdm import tqdm

# Spatial sizes are floored to a multiple of this value before inference.
_SIZE_MULTIPLE = 32
# Sliding-window geometry used by ``DiffusionPipeline.__call__``.
_PATCH_SIZE = 256
_PATCH_STRIDE = 64


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a ``(1, 3, H, W)`` float tensor in ``[-1, 1]``.

    The image is first converted to RGB so that grayscale, palette or RGBA
    inputs do not break the channel transpose below, then resized (LANCZOS)
    so both sides are the largest multiple of 32 not exceeding the original.
    """
    image = image.convert("RGB")
    w, h = (side - side % _SIZE_MULTIPLE for side in image.size)
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0
    tensor = torch.from_numpy(pixels.transpose(2, 0, 1)).unsqueeze(0)
    return 2.0 * tensor - 1.0


def preprocess_mask(mask: Image.Image) -> torch.Tensor:
    """Convert a PIL mask to a binary ``(1, 3, H, W)`` float tensor.

    The mask is converted to single-channel ``"L"``, resized (NEAREST) so both
    sides are multiples of 32, replicated across three channels, and every
    strictly positive value is set to 1.
    """
    mask = mask.convert("L")
    w, h = (side - side % _SIZE_MULTIPLE for side in mask.size)
    mask = mask.resize((w, h), resample=PIL.Image.NEAREST)
    values = np.array(mask).astype(np.float32) / 255.0
    tensor = torch.from_numpy(np.repeat(values[None, ...], 3, axis=0)).unsqueeze(0)
    tensor[tensor > 0] = 1
    return tensor


def _tile_padding(size: int, kernel_size: int, stride: int) -> int:
    """Return the trailing padding that lets sliding windows tile ``size`` exactly."""
    if size <= kernel_size:
        return kernel_size - size
    return -(size - kernel_size) % stride


def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Map the first batch item of a ``[-1, 1]`` NCHW tensor to a uint8 PIL image."""
    array = (tensor / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()[0]
    return Image.fromarray(np.uint8(array * 255.))


class DiffusionPipeline():
    """Mask-guided diffusion sampler built on ``DensePosteriorConditionalUNet``.

    ``__call__`` processes the input in overlapping 256x256 windows and blends
    them back together; ``quick_solve`` runs on a single 512x512 resize and
    yields intermediate predictions.
    """

    def __init__(self, device, checkpoint_path: str = 'net_g_400000.pth'):
        """Build the network on ``device`` and load its EMA weights.

        Args:
            device: Torch device the model and all working tensors live on.
            checkpoint_path: Checkpoint file containing a ``"params_ema"``
                state dict. Defaults to the path that was previously
                hard-coded.
        """
        super().__init__()
        self.device = device
        self.model = DensePosteriorConditionalUNet(
            in_channels=9,
            model_channels=256,
            out_channels=6,
            num_res_blocks=2,
            attention_resolutions=[8, 16, 32],
            dropout=0.0,
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=True
        )
        self.model.eval()
        self.model.to(self.device)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state = torch.load(checkpoint_path, map_location='cpu')
        self.model.load_state_dict(state["params_ema"], strict=True)

    def _make_diffusion(self, diffusion_step):
        """Create the DDIM-respaced Gaussian diffusion used for sampling."""
        return create_gaussian_diffusion(
            steps=1000,
            learn_sigma=True,
            noise_schedule='linear',
            use_kl=False,
            timestep_respacing="ddim" + str(diffusion_step),
            predict_xstart=False,
            rescale_timesteps=False,
            rescale_learned_sigmas=False,
            p2_gamma=1,
            p2_k=1,
        )

    @torch.no_grad()
    def _reverse_step(self, img, lq, mask, step):
        """Run one reverse step; return the new sample and ``pred_xstart``.

        Outside the mask the current sample is replaced by ``lq`` noised to
        timestep ``step``; inside the mask the running sample is kept. The
        network is conditioned on ``lq`` and ``mask`` concatenated on the
        channel axis.
        """
        diffusion = self.eval_gaussian_diffusion
        t = torch.tensor([step] * img.size(0), device=self.device)
        img = img * mask + diffusion.q_sample(lq, t) * (1 - mask)
        out = diffusion.p_mean_variance(
            self.model,
            img.contiguous(),
            t,
            model_kwargs={'latent': torch.cat((lq, mask), dim=1)},
        )
        # no noise when t == 0
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(img.shape) - 1)))
        noise = torch.randn_like(img, device=self.device)
        img = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
        return img, out["pred_xstart"]

    @torch.no_grad()
    def __call__(self, lq, mask, dkernel, diffusion_step) -> Image.Image:
        """Process ``lq`` window by window and return the blended result.

        Args:
            lq: Input PIL image.
            mask: PIL mask; non-zero pixels mark the region to synthesize. It
                is resized to ``lq``'s size if the two differ.
            dkernel: Side length of the all-ones square kernel used to dilate
                the mask.
            diffusion_step: Number of DDIM steps (appended to ``"ddim"`` for
                timestep respacing).

        Returns:
            A PIL image resized back to the original size of ``lq``.
        """
        self.eval_gaussian_diffusion = self._make_diffusion(diffusion_step)
        ow, oh = lq.size
        # The two inputs are resized independently during preprocessing, so
        # align them first to guarantee matching tensor shapes.
        if mask.size != lq.size:
            mask = mask.resize(lq.size, resample=PIL.Image.NEAREST)

        lq = preprocess_image(lq).to(self.device)
        mask = preprocess_mask(mask).to(self.device)
        mask = dilation(mask, torch.ones(dkernel, dkernel, device=self.device))

        # ======== split into overlapping windows ========
        stride = _PATCH_STRIDE
        kernel_size = _PATCH_SIZE
        _, _, h, w = mask.shape
        # Without padding, windows leave the right/bottom border uncovered
        # whenever (side - kernel_size) is not a multiple of stride; those
        # pixels would later be divided by a zero overlap count. Pad so the
        # windows tile the image exactly and crop the padding off at the end.
        pad_h = _tile_padding(h, kernel_size, stride)
        pad_w = _tile_padding(w, kernel_size, stride)
        if pad_h or pad_w:
            lq = F.pad(lq, (0, pad_w, 0, pad_h), mode="replicate")
            mask = F.pad(mask, (0, pad_w, 0, pad_h), mode="replicate")
        padded_h, padded_w = h + pad_h, w + pad_w

        mask = F.unfold(mask, kernel_size=kernel_size, stride=stride)
        lq = F.unfold(lq, kernel_size=kernel_size, stride=stride)
        _, _, num_patches = mask.shape
        mask = rearrange(mask, 'n (c3 h w) l -> (n l) c3 h w', h=kernel_size, w=kernel_size)
        lq = rearrange(lq, 'n (c3 h w) l -> (n l) c3 h w', h=kernel_size, w=kernel_size)

        # ======== sample each window ========
        steps = list(range(self.eval_gaussian_diffusion.num_timesteps))[::-1]
        sub_imgs = []
        for sub_lq, sub_mask in zip(lq.unsqueeze(1), mask.unsqueeze(1)):
            if not torch.sum(sub_mask) > 1:
                # Window has (almost) no masked pixels: pass it through.
                sub_imgs.append(sub_lq)
                continue
            img = torch.randn_like(sub_lq, device=self.device)
            for step in steps:
                img, _ = self._reverse_step(img, sub_lq, sub_mask, step)
            sub_imgs.append(img)
        img = torch.cat(sub_imgs, dim=0)

        # ======== blend windows back together ========
        img = rearrange(img, '(n l) c3 h w -> n (c3 h w) l', h=kernel_size, w=kernel_size, l=num_patches)
        img = F.fold(img, (padded_h, padded_w), kernel_size=kernel_size, stride=stride)
        # Fold sums overlapping windows; dividing by the per-pixel overlap
        # count turns the sum into an average.
        norm_map = F.fold(
            F.unfold(torch.ones_like(img), kernel_size, stride=stride),
            (padded_h, padded_w),
            kernel_size,
            stride=stride,
        )
        img /= norm_map
        img = img[..., :h, :w]

        result = _tensor_to_pil(img)
        return result.resize((ow, oh), resample=PIL.Image.LANCZOS)

    @torch.no_grad()
    def quick_solve(self, lq, mask, dkernel, diffusion_step) -> Iterator[Image.Image]:
        """Sample on a single 512x512 resize, yielding progress images.

        Args:
            lq: Input PIL image.
            mask: PIL mask; non-zero pixels mark the region to synthesize.
            dkernel: Side length of the all-ones square kernel used to dilate
                the mask.
            diffusion_step: Number of DDIM steps.

        Yields:
            After every reverse step, the model's ``pred_xstart`` as a PIL
            image at the original size of ``lq``; after the last step, the
            final sample in the same form.
        """
        self.eval_gaussian_diffusion = self._make_diffusion(diffusion_step)
        ow, oh = lq.size
        lq = lq.resize((512, 512), resample=Image.LANCZOS)
        mask = mask.resize((512, 512), resample=Image.NEAREST)

        lq = preprocess_image(lq).to(self.device)
        mask = preprocess_mask(mask).to(self.device)
        mask = dilation(mask, torch.ones(dkernel, dkernel, device=self.device))

        img = torch.randn_like(lq, device=self.device)
        steps = list(range(self.eval_gaussian_diffusion.num_timesteps))[::-1]
        for step in steps:
            img, pred_xstart = self._reverse_step(img, lq, mask, step)
            yield _tensor_to_pil(pred_xstart).resize((ow, oh), resample=Image.LANCZOS)
        yield _tensor_to_pil(img).resize((ow, oh), resample=Image.LANCZOS)