import os

import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from gfpgan import GFPGANv1Clean, GFPGANer
from torch.hub import get_dir


class MyGFPGANer(GFPGANer):
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces will be pasted back to the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model checkpoint.
        upscale (float): The upscale factor of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Option: clean | RestoreFormer. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
        device (torch.device): Device to run the model on. Default: None, which selects CUDA when available.
    """

    def __init__(
        self,
        model_path,
        upscale=2,
        arch="clean",
        channel_multiplier=2,
        bg_upsampler=None,
        device=None,
    ):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # Run on GPU when available, unless a device is given explicitly.
        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )

        # Build the restoration network for the requested architecture.
        if arch == "clean":
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True,
            )
        elif arch == "RestoreFormer":
            from gfpgan.archs.restoreformer_arch import RestoreFormer

            self.gfpgan = RestoreFormer()
        else:
            # Fail fast instead of hitting an AttributeError later.
            raise ValueError(f"Unsupported GFPGAN architecture: {arch}")

        # Auxiliary detection/parsing weights go into torch hub's checkpoint directory.
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, "checkpoints")

        # Helper that detects, aligns, and crops faces, and later pastes them back.
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model="retinaface_resnet50",
            save_ext="png",
            use_parse=True,
            device=self.device,
            model_rootpath=model_dir,
        )

        # Load the GFPGAN weights, preferring the EMA parameters when present.
        loadnet = torch.load(model_path)
        if "params_ema" in loadnet:
            keyname = "params_ema"
        else:
            keyname = "params"
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)
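

# Minimal usage sketch: restore the faces in a single image with MyGFPGANer.
# The checkpoint path and image paths below are placeholders; point them at
# your own GFPGAN weights and input image. enhance() is inherited from
# GFPGANer: it crops the faces, restores them, and pastes them back onto the
# (optionally upsampled) background.
if __name__ == "__main__":
    import cv2

    restorer = MyGFPGANer(
        model_path="experiments/pretrained_models/GFPGANv1.4.pth",  # placeholder path
        upscale=2,
        arch="clean",
        channel_multiplier=2,
        bg_upsampler=None,
    )
    img = cv2.imread("inputs/sample.jpg", cv2.IMREAD_COLOR)  # placeholder input image
    cropped_faces, restored_faces, restored_img = restorer.enhance(
        img, has_aligned=False, only_center_face=False, paste_back=True
    )
    cv2.imwrite("restored_sample.png", restored_img)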