import os

import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from gfpgan import GFPGANv1Clean, GFPGANer
from torch.hub import get_dir


class MyGFPGANer(GFPGANer):
    """Helper for face restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restore the resized faces. The background is upsampled
    with the bg_upsampler. Finally, the faces will be pasted back onto the
    upsampled background image.

    Args:
        model_path (str): Path to the restoration model checkpoint. It is
            passed directly to ``torch.load``; no download is performed here.
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The restoration architecture.
            Option: clean | RestoreFormer. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Only used when ``arch == "clean"``. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background.
            Default: None.
        device (torch.device | str | None): Device to run the model on.
            When None, CUDA is used if available, otherwise CPU.
            Default: None.

    Raises:
        ValueError: If ``arch`` is not one of the supported architectures.
    """

    def __init__(
        self,
        model_path,
        upscale=2,
        arch="clean",
        channel_multiplier=2,
        bg_upsampler=None,
        device=None,
    ):
        # NOTE(review): super().__init__() is not called; this constructor
        # sets every attribute itself — presumably because the parent
        # constructor would build and load its own model. Confirm before
        # changing.
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # Pick the execution device: the explicit argument wins, otherwise
        # fall back to CUDA when available and CPU when not.
        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )

        # Build the restoration network for the requested architecture.
        if arch == "clean":
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True,
            )
        elif arch == "RestoreFormer":
            # Imported lazily so the dependency is only needed when used.
            from gfpgan.archs.restoreformer_arch import RestoreFormer

            self.gfpgan = RestoreFormer()
        else:
            # Fail early with a clear message instead of hitting an
            # AttributeError on ``self.gfpgan`` after the checkpoint is read.
            raise ValueError(
                f"Unsupported arch {arch!r}; expected 'clean' or 'RestoreFormer'."
            )

        # The torch hub "checkpoints" directory is handed to the face helper
        # as the root path for its models.
        model_dir = os.path.join(get_dir(), "checkpoints")

        # Face helper: detection, cropping to 512x512 and paste-back.
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model="retinaface_resnet50",
            save_ext="png",
            use_parse=True,
            device=self.device,
            model_rootpath=model_dir,
        )

        # Deserialize onto CPU so that checkpoints saved from a CUDA device
        # can still be loaded on CPU-only hosts; the model is moved to
        # ``self.device`` below.
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        loadnet = torch.load(model_path, map_location="cpu")

        # Prefer the "params_ema" entry when the checkpoint has one,
        # otherwise use "params".
        keyname = "params_ema" if "params_ema" in loadnet else "params"
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)