#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-07-13 16:59:27

import random
import numpy as np

from omegaconf import OmegaConf
from ResizeRight.resize_right import resize

from utils import util_net
from utils import util_image
from utils import util_common

import torch

from basicsr.utils import img2tensor
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
from facelib.utils.face_restoration_helper import FaceRestoreHelper

class BaseSampler:
    def __init__(self, configs):
        '''
        Input:
            configs: config, see the yaml file in folder ./configs/sample/
        '''
        self.configs = configs
        self.display = configs.display
        self.diffusion_cfg = configs.diffusion

        self.setup_dist()   # set up self.device and self.rank
        self.setup_seed()   # set up the random seed
        self.build_model()

    def setup_seed(self, seed=None):
        seed = self.configs.seed if seed is None else seed
        seed += (self.rank + 1) * 10000
        if self.rank == 0 and self.display:
            print(f'Setting random seed {seed}')
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def setup_dist(self):
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            print('Running on GPU...')
        else:
            self.device = torch.device('cpu')
            print('Running on CPU...')
        self.rank = 0

    def build_model(self):
        obj = util_common.get_obj_from_str(self.configs.diffusion.target)
        self.diffusion = obj(**self.configs.diffusion.params)

        obj = util_common.get_obj_from_str(self.configs.model.target)
        model = obj(**self.configs.model.params).to(self.device)
        if self.configs.model.ckpt_path is not None:
            self.load_model(model, self.configs.model.ckpt_path)
        self.model = model
        self.model.eval()

    def load_model(self, model, ckpt_path=None):
        if ckpt_path is not None:
            if self.rank == 0 and self.display:
                print(f'Loading from {ckpt_path}...')
            ckpt = torch.load(ckpt_path, map_location=self.device)
            util_net.reload_model(model, ckpt)
            if self.rank == 0 and self.display:
                print('Loading done')

    def reset_diffusion(self, diffusion_cfg):
        # rebuild the diffusion object from a new config (mirrors build_model)
        obj = util_common.get_obj_from_str(diffusion_cfg.target)
        self.diffusion = obj(**diffusion_cfg.params)

class DifIRSampler(BaseSampler):
    def build_model(self):
        super().build_model()

        # restoration model for the initial estimate
        if self.configs.model_ir is not None:
            obj = util_common.get_obj_from_str(self.configs.model_ir.target)
            model_ir = obj(**self.configs.model_ir.params).to(self.device)
            if self.configs.model_ir.ckpt_path is not None:
                self.load_model(model_ir, self.configs.model_ir.ckpt_path)
            self.model_ir = model_ir
            self.model_ir.eval()

        if not self.configs.aligned:
            # face detection model
            self.face_helper = FaceRestoreHelper(
                self.configs.detection.upscale,
                face_size=self.configs.im_size,
                crop_ratio=(1, 1),
                det_model=self.configs.detection.det_model,
                save_ext='png',
                use_parse=True,
                device=self.device,
            )

            # background super-resolution
            if self.configs.background_enhance or self.configs.face_upsample:
                bg_model = RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=23,
                    num_grow_ch=32,
                    scale=2,
                )
                self.bg_model = RealESRGANer(
                    scale=2,
                    model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                    model=bg_model,
                    tile=400,
                    tile_pad=10,
                    pre_pad=0,
                    half=True,  # need to set False in CPU mode
                    device=torch.device(f'cuda:{self.rank}'),
                )
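    # ------------------------------------------------------------------
    # Note on `util_common.get_obj_from_str` used in build_model(): the
    # helper is defined outside this file, but it is presumably the
    # standard "import an object from a dotted path" utility, roughly
    # (a minimal sketch, assuming that common pattern):
    #
    #     import importlib
    #
    #     def get_obj_from_str(string):
    #         module_name, obj_name = string.rsplit('.', 1)
    #         return getattr(importlib.import_module(module_name), obj_name)
    #
    # so that `target` entries in the YAML configs (dotted class paths)
    # can be instantiated directly with their `params`.
    # ------------------------------------------------------------------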
    def sample_func_ir_aligned(
            self,
            y0,
            start_timesteps=None,
            post_fun=None,
            model_kwargs_ir=None,
            need_restoration=True,
            ):
        '''
        Input:
            y0: n x c x h x w torch tensor, low-quality image, [0, 1], RGB,
                or h x w x c numpy array, [0, 255], uint8, BGR
            start_timesteps: integer, range [0, num_timesteps-1];
                for accelerated sampling (e.g., 'ddim250'), range [0, 249]
            post_fun: post-processing for the enhanced image
            model_kwargs_ir: additional parameters for the restoration model
            need_restoration: whether to run the restoration model before diffusion
        Output:
            sample: n x c x h x w torch tensor, [0, 1], RGB
        '''
        if not isinstance(y0, torch.Tensor):
            y0 = img2tensor(y0, bgr2rgb=True, float32=True).unsqueeze(0) / 255.  # 1 x c x h x w, [0, 1]

        if start_timesteps is None:
            start_timesteps = self.diffusion.num_timesteps

        if post_fun is None:
            post_fun = lambda x: util_image.normalize_th(
                im=x,
                mean=0.5,
                std=0.5,
                reverse=False,
            )

        # basic image restoration
        device = next(self.model.parameters()).device
        y0 = y0.to(device=device, dtype=torch.float32)

        h_old, w_old = y0.shape[2:4]
        if not (h_old == self.configs.im_size and w_old == self.configs.im_size):
            y0 = resize(y0, out_shape=(self.configs.im_size,) * 2).to(torch.float32)

        if need_restoration:
            with torch.no_grad():
                if model_kwargs_ir is None:
                    im_hq = self.model_ir(y0)
                else:
                    im_hq = self.model_ir(y0, **model_kwargs_ir)
        else:
            im_hq = y0
        im_hq.clamp_(0.0, 1.0)

        # diffuse im_hq to the intermediate timestep `start_timesteps`
        yt = self.diffusion.q_sample(
            x_start=post_fun(im_hq),
            t=torch.tensor([start_timesteps, ] * im_hq.shape[0], device=device),
        )
        assert yt.shape[-1] == self.configs.im_size and yt.shape[-2] == self.configs.im_size

        if 'ddim' in self.configs.diffusion.params.timestep_respacing:
            sample = self.diffusion.ddim_sample_loop(
                self.model,
                shape=yt.shape,
                noise=yt,
                start_timesteps=start_timesteps,
                clip_denoised=True,
                denoised_fn=None,
                model_kwargs=None,
                device=None,
                progress=False,
                eta=0.0,
            )
        else:
            sample = self.diffusion.p_sample_loop(
                self.model,
                shape=yt.shape,
                noise=yt,
                start_timesteps=start_timesteps,
                clip_denoised=True,
                denoised_fn=None,
                model_kwargs=None,
                device=None,
                progress=False,
            )

        sample = util_image.normalize_th(sample, reverse=True).clamp(0.0, 1.0)

        if not (h_old == self.configs.im_size and w_old == self.configs.im_size):
            sample = resize(sample, out_shape=(h_old, w_old)).clamp(0.0, 1.0)

        return sample, im_hq
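    # ------------------------------------------------------------------
    # Background on the q_sample step above: for Gaussian diffusion,
    # q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
    # so the noisy start yt can be drawn in closed form, roughly
    # (a sketch of the standard implementation; the actual code lives
    # in the configured diffusion class):
    #
    #     noise = torch.randn_like(x_start)
    #     yt = sqrt_alphas_cumprod[t] * x_start \
    #             + sqrt_one_minus_alphas_cumprod[t] * noise
    #
    # i.e., the restored estimate im_hq is partially re-noised to level
    # `start_timesteps`, and the reverse (DDIM or DDPM) loop then
    # refines it from there instead of from pure noise.
    # ------------------------------------------------------------------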
    def sample_func_bfr_unaligned(
            self,
            y0,
            bs=16,
            start_timesteps=None,
            post_fun=None,
            model_kwargs_ir=None,
            need_restoration=True,
            only_center_face=False,
            draw_box=False,
            ):
        '''
        Input:
            y0: h x w x c numpy array, uint8, BGR
            bs: batch size for face restoration
            start_timesteps: integer, range [0, num_timesteps-1];
                for accelerated sampling (e.g., 'ddim250'), range [0, 249]
            post_fun: post-processing for the enhanced image
            model_kwargs_ir: additional parameters for the restoration model
            need_restoration: whether to run the restoration model before diffusion
            only_center_face: restore only the center face
            draw_box: draw a box for each face
        Output:
            restored_img: h x w x c numpy array, uint8, BGR
            restored_faces: list of h x w x c numpy arrays, uint8, BGR
            cropped_faces: list of h x w x c numpy arrays, uint8, BGR
        '''
        def _process_batch(cropped_faces_list):
            cropped_face_t = np.stack(
                img2tensor(cropped_faces_list, bgr2rgb=True, float32=True),
                axis=0) / 255.
            cropped_face_t = torch.from_numpy(cropped_face_t).to(self.device)
            restored_faces = self.sample_func_ir_aligned(
                cropped_face_t,
                start_timesteps=start_timesteps,
                post_fun=post_fun,
                model_kwargs_ir=model_kwargs_ir,
                need_restoration=need_restoration,
            )[0]    # [0, 1], b x c x h x w
            return restored_faces

        assert not self.configs.aligned

        self.face_helper.clean_all()
        self.face_helper.read_image(y0)
        num_det_faces = self.face_helper.get_face_landmarks_5(
            only_center_face=only_center_face,
            resize=640,
            eye_dist_threshold=5,
        )
        # align and warp each face
        self.face_helper.align_warp_face()

        num_cropped_face = len(self.face_helper.cropped_faces)
        if num_cropped_face > bs:
            restored_faces = []
            for idx_start in range(0, num_cropped_face, bs):
                idx_end = min(idx_start + bs, num_cropped_face)
                current_cropped_faces = self.face_helper.cropped_faces[idx_start:idx_end]
                current_restored_faces = _process_batch(current_cropped_faces)
                current_restored_faces = util_image.tensor2img(
                    list(current_restored_faces.split(1, dim=0)),
                    rgb2bgr=True,
                    min_max=(0, 1),
                    out_type=np.uint8,
                )
                restored_faces.extend(current_restored_faces)
        else:
            restored_faces = _process_batch(self.face_helper.cropped_faces)
            restored_faces = util_image.tensor2img(
                list(restored_faces.split(1, dim=0)),
                rgb2bgr=True,
                min_max=(0, 1),
                out_type=np.uint8,
            )

        for xx in restored_faces:
            self.face_helper.add_restored_face(xx)

        # background super-resolution for pasting back
        if self.configs.background_enhance:
            bg_img = self.bg_model.enhance(y0, outscale=self.configs.detection.upscale)[0]
        else:
            bg_img = None
        self.face_helper.get_inverse_affine(None)

        # paste each restored face back into the input image
        if self.configs.face_upsample:
            restored_img = self.face_helper.paste_faces_to_input_image(
                upsample_img=bg_img,
                draw_box=draw_box,
                face_upsampler=self.bg_model,
            )
        else:
            restored_img = self.face_helper.paste_faces_to_input_image(
                upsample_img=bg_img,
                draw_box=draw_box,
            )

        cropped_faces = self.face_helper.cropped_faces
        return restored_img, restored_faces, cropped_faces

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./save_dir",
        help="Folder to save the sampled images",
    )
    parser.add_argument(
        "--gpu_id",
        type=str,
        default='',
        help="GPU indices, e.g., 025 for GPUs 0, 2, and 5",
    )
    parser.add_argument(
        "--cfg_path",
        type=str,
        default='./configs/sample/iddpm_ffhq256.yaml',
        help="Path of the config file",
    )
    parser.add_argument(
        "--bs",
        type=int,
        default=32,
        help="Batch size",
    )
    parser.add_argument(
        "--num_images",
        type=int,
        default=3000,
        help="Number of sampled images",
    )
    parser.add_argument(
        "--timestep_respacing",
        type=str,
        default='1000',
        help="Sampling steps for acceleration, e.g., ddim250",
    )
    args = parser.parse_args()

    configs = OmegaConf.load(args.cfg_path)
    configs.gpu_id = args.gpu_id
    configs.diffusion.params.timestep_respacing = args.timestep_respacing

    # NOTE: `DiffusionSampler` (an unconditional sampler exposing
    # `sample_func(bs, num_images, save_dir)`) is not defined in this
    # file; it is assumed to be provided elsewhere in the project.
    sampler_dist = DiffusionSampler(configs)
    sampler_dist.sample_func(
        bs=args.bs,
        num_images=args.num_images,
        save_dir=args.save_dir,
    )
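# ----------------------------------------------------------------------
# Usage sketch for blind face restoration on an unaligned photo
# (illustrative only; the config and image paths are placeholders):
#
#     import cv2
#     from omegaconf import OmegaConf
#
#     configs = OmegaConf.load('./configs/sample/<your_bfr_config>.yaml')
#     sampler = DifIRSampler(configs)
#     im_lq = cv2.imread('whole_photo.png')      # h x w x c, uint8, BGR
#     restored_img, restored_faces, cropped_faces = \
#             sampler.sample_func_bfr_unaligned(im_lq, bs=8, draw_box=False)
#     cv2.imwrite('restored.png', restored_img)
# ----------------------------------------------------------------------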