"""Sampling pipeline pairing a conditional restoration UNet with an ILVR UNet.

``DiffusionPipeline`` loads both networks from local checkpoints and, when
called, runs two reverse-diffusion chains side by side over a low-quality
input image, yielding preview image grids as sampling progresses.
"""
import numpy as np
import PIL
from PIL import Image
import torch
from diffusion_arch import ILVRUNetModel, ConditionalUNetModel
from guided_diffusion.script_util import create_gaussian_diffusion
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid


def _floor_to_multiple_of_32(size):
    """Round every dimension in ``size`` down to the nearest multiple of 32."""
    return tuple(x - x % 32 for x in size)


def preprocess_image(image):
    """Convert a PIL image into a model-ready tensor.

    The image is resized (Lanczos) so that width and height become the largest
    multiples of 32 not exceeding the originals. Pixel values are then divided
    by 255 and mapped through ``2 * x - 1`` (i.e. into [-1, 1] for 8-bit input).

    Args:
        image: PIL image. Its array form must have a channel axis (e.g. RGB);
            a single-channel "L" image yields a 2-D array and the
            ``transpose(2, 0, 1)`` below fails.

    Returns:
        Float32 tensor of shape (1, C, H, W).
    """
    w, h = _floor_to_multiple_of_32(image.size)
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0)
    return 2.0 * image - 1.0


def preprocess_mask(mask):
    """Convert a PIL mask into a binary 3-channel tensor.

    The mask is converted to grayscale and resized (nearest-neighbour, so no
    new intermediate values are introduced) so that width and height become
    the largest multiples of 32 not exceeding the originals. Any non-zero
    pixel becomes 1.0; zero pixels stay 0.0.

    Returns:
        Float32 tensor of shape (1, 3, H, W) containing only 0.0 and 1.0; the
        grayscale plane is repeated across the 3 channels.
    """
    mask = mask.convert("L")
    w, h = _floor_to_multiple_of_32(mask.size)
    mask = mask.resize((w, h), resample=PIL.Image.NEAREST)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = torch.from_numpy(np.repeat(mask[None, ...], 3, axis=0)).unsqueeze(0)
    mask[mask > 0] = 1
    return mask


def _unet_kwargs():
    """Return a fresh dict of the architecture hyper-parameters shared by both UNets.

    A new dict (and a new ``attention_resolutions`` list) is built on every call
    so the two models never share a mutable argument. The values must match the
    architecture stored in the checkpoints, because ``load_state_dict`` is
    strict by default.
    """
    return dict(
        in_channels=3,
        model_channels=128,
        # 3 epsilon channels + 3 variance channels: the diffusion is created
        # with learn_sigma=True and __call__ splits the output in half on dim 1.
        out_channels=6,
        num_res_blocks=1,
        attention_resolutions=[16],
        channel_mult=(1, 1, 2, 2, 4, 4),
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=4,
        num_head_channels=64,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=False,
    )


def _grid_to_pil(batch):
    """Tile a batch of [-1, 1] images into one grid and convert it to an 8-bit PIL image."""
    grid = (make_grid(batch) / 2 + 0.5).clamp(0, 1)
    array = grid.cpu().numpy().transpose(1, 2, 0) * 255.0
    # np.uint8 on floats truncates toward zero rather than rounding.
    return Image.fromarray(np.uint8(array))


class DiffusionPipeline:
    """Dual-chain diffusion sampler driven by a low-quality input image.

    Attributes:
        device: Device that both models and all sampling tensors are placed on.
        diffusion_model: ``ILVRUNetModel`` in eval mode, with weights loaded from
            ``ilvr_ckpt`` (default ``./ffhq_10m.pt`` -- presumably an FFHQ face
            prior judging by the filename; confirm).
        diffusion_restoration_model: ``ConditionalUNetModel`` in eval mode. It
            receives the input image through an ``lq`` keyword argument, and its
            weights come from the ``'params'`` entry of ``restoration_ckpt``.
    """

    def __init__(self, device, ilvr_ckpt='./ffhq_10m.pt',
                 restoration_ckpt='./net_g_250000.pth'):
        """Build both UNets on ``device`` and load their checkpoints.

        Args:
            device: Target torch device (anything accepted by ``Module.to``).
            ilvr_ckpt: Path to the state dict for ``ILVRUNetModel``.
            restoration_ckpt: Path to a checkpoint dict whose ``'params'`` entry
                is the state dict for ``ConditionalUNetModel``.

        Note:
            ``torch.load`` unpickles the file, which can execute arbitrary code.
            Only point these paths at trusted checkpoints.
        """
        self.device = device

        self.diffusion_model = self._build(
            ILVRUNetModel(**_unet_kwargs()),
            torch.load(ilvr_ckpt, map_location='cpu'),
            device,
        )
        self.diffusion_restoration_model = self._build(
            ConditionalUNetModel(dropout=0.0, **_unet_kwargs()),
            torch.load(restoration_ckpt, map_location='cpu')['params'],
            device,
        )

    @staticmethod
    def _build(model, state_dict, device):
        """Move ``model`` to ``device``, switch it to eval mode, and load ``state_dict``."""
        model = model.to(device).eval()
        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def _p_sample(diffusion, model, x, t, nonzero_mask, **pmv_kwargs):
        """Take one ancestral reverse step from ``x`` at timestep ``t``.

        Returns:
            ``(sample, pred_xstart)``: ``sample`` is the posterior mean plus
            ``exp(0.5 * log_variance)``-scaled Gaussian noise (the noise is
            zeroed wherever ``nonzero_mask`` is 0). ``pred_xstart`` is the
            model's clean-image estimate.
        """
        out = diffusion.p_mean_variance(model, x, t, **pmv_kwargs)
        noise = torch.randn_like(x)
        sample = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
        return sample, out["pred_xstart"]

    @torch.no_grad()
    def __call__(self, lq, diffusion_step, binoising_step, grid_size):
        """Run both sampling chains on ``lq`` and yield preview grids.

        Args:
            lq: Input PIL image. It is converted to RGB and resized to 256x256.
            diffusion_step: Number of respaced sampling steps (passed as
                ``timestep_respacing`` to ``create_gaussian_diffusion``).
            binoising_step: Respaced timestep index below which the
                restoration-guided update is applied to the ILVR chain.
            grid_size: Number of samples drawn per chain; cast with ``int``,
                just like ``diffusion_step``.

        Yields:
            ``[restoration_grid, ilvr_grid]`` as PIL images. On every even
            timestep index the grids show each chain's ``pred_xstart``; after
            the loop a final pair shows the finished samples.
        """
        lq = lq.convert("RGB").resize((256, 256), resample=Image.LANCZOS)
        grid_size = int(grid_size)

        diffusion = create_gaussian_diffusion(
            steps=1000,
            learn_sigma=True,
            noise_schedule='linear',
            use_kl=False,
            timestep_respacing=str(int(diffusion_step)),
            predict_xstart=False,
            rescale_timesteps=False,
            rescale_learned_sigmas=False,
        )

        lq_img_th = preprocess_image(lq).to(self.device).repeat([grid_size, 1, 1, 1])
        batch_size = lq_img_th.size(0)
        channels = lq_img_th.size(1)

        # Independent starting noise for the ILVR chain (img) and the
        # restoration chain (s_img); draw order matches the original code.
        img = torch.randn_like(lq_img_th, device=self.device)
        s_img = torch.randn_like(lq_img_th, device=self.device)
        mask_shape = (-1,) + (1,) * (img.dim() - 1)

        for i in reversed(range(diffusion.num_timesteps)):
            t = torch.tensor([i] * batch_size, device=self.device)
            nonzero_mask = (t != 0).float().view(*mask_shape)  # no noise when t == 0

            # Restoration chain: conditional model, conditioned on the input image.
            s_img, s_img_pred = self._p_sample(
                diffusion, self.diffusion_restoration_model, s_img, t, nonzero_mask,
                model_kwargs={'lq': lq_img_th},
            )

            if i < binoising_step:
                # Replace the ILVR sample with the restoration model's clamped
                # x0 estimate (computed on img), re-noised to level t by q_sample.
                model_output = diffusion._wrap_model(self.diffusion_restoration_model)(
                    img, t, lq=lq_img_th
                )
                eps, _ = torch.split(model_output, channels, dim=1)
                pred_xstart = diffusion._predict_xstart_from_eps(img, t, eps).clamp(-1, 1)
                img = diffusion.q_sample(pred_xstart, t)

            # ILVR chain: the second UNet takes no extra conditioning kwargs.
            img, img_pred = self._p_sample(
                diffusion, self.diffusion_model, img, t, nonzero_mask
            )

            if i % 2 == 0:
                yield [_grid_to_pil(s_img_pred), _grid_to_pil(img_pred)]

        yield [_grid_to_pil(s_img), _grid_to_pil(img)]