|
import contextlib
import os

import comfy
import comfy.utils
import cv2
import folder_paths
import numpy as np
import torch
from comfy import model_management
from PIL import Image

from ..log import NullWriter, log
from ..utils import get_model_path, np2tensor, pil2tensor, tensor2np
|
|
|
|
|
class MTB_LoadFaceEnhanceModel:
    """Loads a GFPGan or RestoreFormer model for face enhancement."""

    def __init__(self) -> None:
        pass

    @classmethod
    def get_models_root(cls):
        """Return the ``(face_restore, upscale_models)`` search roots.

        * ``(face_restore, None)`` when the ``face_restore`` path exists.
        * ``(face_restore, upscale_models)`` when only ``upscale_models``
          exists (the returned ``face_restore`` path does not exist on disk).
        * ``(None, None)`` when neither exists.
        """
        fr = get_model_path("face_restore")
        if fr.exists():
            return (fr, None)

        um = get_model_path("upscale_models")
        return (fr, um) if um.exists() else (None, None)

    @staticmethod
    def _list_face_models(root):
        """Return the GFPGAN / RestoreFormer ``.pth`` files directly in *root*."""
        return [
            x
            for x in root.iterdir()
            if x.name.endswith(".pth")
            and ("GFPGAN" in x.name or "RestoreFormer" in x.name)
        ]

    @classmethod
    def get_models(cls):
        """List the face-restoration checkpoints available for selection.

        Looks in the ``face_restore`` folder first and falls back to the
        ``upscale_models`` folder. When neither folder exists, logs a warning
        (only once per class) and returns an empty list.
        """
        fr_models_path, um_models_path = cls.get_models_root()

        if fr_models_path is None and um_models_path is None:
            if not hasattr(cls, "_warned"):
                log.warning("Face restoration models not found.")
                cls._warned = True
            return []

        if fr_models_path is not None and fr_models_path.exists():
            return cls._list_face_models(fr_models_path)

        # Fallback: also accept checkpoints stored next to the upscale models
        # (presumably the older location for these files).
        if um_models_path is not None and um_models_path.exists():
            return cls._list_face_models(um_models_path)
        return []

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_name": (
                    [x.name for x in cls.get_models()],
                    {"default": "None"},
                ),
                "upscale": ("INT", {"default": 1}),
            },
            "optional": {"bg_upsampler": ("UPSCALE_MODEL", {"default": None})},
        }

    RETURN_TYPES = ("FACEENHANCE_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def load_model(self, model_name, upscale=2, bg_upsampler=None):
        """Instantiate a ``GFPGANer`` for the selected checkpoint.

        Args:
            model_name: file name of a checkpoint listed by ``get_models``.
                Names containing ``"RestoreFormer"`` use the RestoreFormer
                architecture; everything else uses the ``"clean"`` GFPGAN arch.
            upscale: upscale factor passed to ``GFPGANer``. Overridden by
                ``bg_upsampler.scale`` when a background upsampler is given.
            bg_upsampler: optional upscale model. It is wrapped in
                ``BGUpscaleWrapper`` before being handed to ``GFPGANer``.

        Returns:
            A 1-tuple holding the ``GFPGANer`` instance.

        Raises:
            FileNotFoundError: neither model folder exists.
        """
        from gfpgan import GFPGANer

        basic = "RestoreFormer" not in model_name

        fr_root, um_root = self.get_models_root()
        if fr_root is not None and fr_root.exists():
            models_root = fr_root
        elif um_root is not None:
            models_root = um_root
        else:
            raise FileNotFoundError(
                "No face_restore or upscale_models folder found to load "
                f"{model_name!r} from."
            )

        if bg_upsampler is not None:
            log.warning(
                f"Upscale value overridden to {bg_upsampler.scale} from bg_upsampler"
            )
            upscale = bg_upsampler.scale
            bg_upsampler = BGUpscaleWrapper(bg_upsampler)

        # Silence GFPGANer's console output. redirect_stdout restores the
        # *previously active* stream (not sys.__stdout__), even if loading
        # raises, so a wrapped stdout installed by the host is not lost.
        with contextlib.redirect_stdout(NullWriter()):
            model = GFPGANer(
                model_path=(models_root / model_name).as_posix(),
                upscale=upscale,
                arch="clean" if basic else "RestoreFormer",
                channel_multiplier=2,
                bg_upsampler=bg_upsampler,
            )

        return (model,)
|
|
|
|
|
class BGUpscaleWrapper:
    """Expose an upscale model through an ``enhance(img, outscale=...)`` method.

    Used as the ``bg_upsampler`` of a GFPGANer in ``load_model``. The result
    is returned as a 1-tuple (presumably because the caller indexes ``[0]``
    on it; TODO confirm against the GFPGAN version in use).
    """

    def __init__(self, upscale_model, tile: int = 128 + 64, overlap: int = 8) -> None:
        """
        Args:
            upscale_model: model exposing ``.to()``, ``.cpu()``, ``.scale`` and
                ``__call__`` (the members used by ``enhance``).
            tile: tile edge length in pixels used for tiled upscaling.
            overlap: overlap in pixels between neighbouring tiles.
        """
        self.upscale_model = upscale_model
        self.tile = tile
        self.overlap = overlap

    def enhance(self, img: np.ndarray, outscale=2):
        """Upscale one image with the wrapped model using tiled inference.

        The model is moved to the ComfyUI torch device for the call and always
        moved back to CPU afterwards, even if upscaling fails.

        Args:
            img: image array converted with ``np2tensor``. NOTE(review):
                presumably the BGR uint8 image GFPGAN hands its bg_upsampler;
                channels are passed through unchanged.
            outscale: accepted for interface compatibility but ignored; the
                wrapped model's own ``scale`` determines the output size.

        Returns:
            A 1-tuple with the first upscaled image as produced by ``tensor2np``.
        """
        device = model_management.get_torch_device()
        self.upscale_model.to(device)

        try:
            # Channels-last numpy image -> channels-first tensor on device.
            imgt = np2tensor(img).movedim(-1, -3).to(device)

            steps = imgt.shape[0] * comfy.utils.get_tiled_scale_steps(
                imgt.shape[3],
                imgt.shape[2],
                tile_x=self.tile,
                tile_y=self.tile,
                overlap=self.overlap,
            )
            log.debug(f"Steps: {steps}")

            pbar = comfy.utils.ProgressBar(steps)
            upscaled = comfy.utils.tiled_scale(
                imgt,
                self.upscale_model,
                tile_x=self.tile,
                tile_y=self.tile,
                overlap=self.overlap,
                upscale_amount=self.upscale_model.scale,
                pbar=pbar,
            )
        finally:
            # Release device memory even when tiling raises (e.g. OOM).
            self.upscale_model.cpu()

        upscaled = torch.clamp(upscaled.movedim(-3, -1), min=0, max=1.0)
        return (tensor2np(upscaled)[0],)
|
|
|
|
|
import sys |
|
|
|
|
|
class MTB_RestoreFace:
    """Uses GFPGan to restore faces"""

    def __init__(self) -> None:
        pass

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "restore"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "model": ("FACEENHANCE_MODEL",),
                # Forwarded to model.enhance as ``has_aligned``.
                "aligned": ("BOOLEAN", {"default": False}),
                # Forwarded to model.enhance as ``only_center_face``.
                "only_center_face": ("BOOLEAN", {"default": False}),
                "weight": ("FLOAT", {"default": 0.5}),
                "save_tmp_steps": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "preserve_alpha": ("BOOLEAN", {"default": True}),
            },
        }

    def do_restore(
        self,
        image: torch.Tensor,
        model,
        aligned,
        only_center_face,
        weight,
        save_tmp_steps,
        preserve_alpha: bool = False,
    ) -> torch.Tensor:
        """Restore the faces of a single image.

        Args:
            image: one element of the batch (``image[i]`` in ``restore``),
                channels last. A 4th channel is treated as alpha.
            model: the ``FACEENHANCE_MODEL`` produced by the loader node.
            aligned: forwarded to ``model.enhance`` as ``has_aligned``.
            only_center_face: forwarded to ``model.enhance``.
            weight: forwarded to ``model.enhance``. A warning reports that it
                currently has no effect.
            save_tmp_steps: write the cropped / restored face crops to the
                temp directory.
            preserve_alpha: re-attach the input alpha channel, resized to the
                output size, to the result.

        Returns:
            The restored image converted with ``pil2tensor``.
        """
        pimage = tensor2np(image)[0]
        height, width = pimage.shape[0], pimage.shape[1]

        alpha_channel = None
        if image.size(-1) == 4:
            if preserve_alpha:
                alpha_channel = pimage[:, :, 3]
            # Drop alpha *before* the colour conversion so only RGB data is
            # converted and handed to the model.
            pimage = pimage[:, :, :3]

        source_img = cv2.cvtColor(np.ascontiguousarray(pimage), cv2.COLOR_RGB2BGR)

        # Silence the model's console output. redirect_stdout restores the
        # previously active stream even if enhance() raises.
        with contextlib.redirect_stdout(NullWriter()):
            cropped_faces, restored_faces, restored_img = model.enhance(
                source_img,
                has_aligned=aligned,
                only_center_face=only_center_face,
                paste_back=True,
                weight=weight,
            )
        log.warning(f"Weight value has no effect for now. (value: {weight})")

        if save_tmp_steps:
            self.save_intermediate_images(cropped_faces, restored_faces, height, width)

        if restored_img is None:
            # model.enhance may not return a pasted-back image (presumably in
            # aligned mode). Fall back to the first restored face, or to the
            # untouched input when no face was restored, instead of crashing
            # in pil2tensor(None).
            restored_img = restored_faces[0] if restored_faces else source_img

        output = Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))

        if alpha_channel is not None:
            alpha_resized = Image.fromarray(alpha_channel).resize(
                output.size, Image.LANCZOS
            )
            output.putalpha(alpha_resized)

        return pil2tensor(output)

    def restore(
        self,
        image: torch.Tensor,
        model,
        aligned=False,
        only_center_face=False,
        weight=0.5,
        save_tmp_steps=True,
        preserve_alpha: bool = False,
    ) -> tuple[torch.Tensor]:
        """Restore faces in every image of the batch.

        Returns:
            A 1-tuple with the per-image results of ``do_restore``,
            concatenated along dim 0. An empty batch is returned unchanged.
        """
        restored = [
            self.do_restore(
                frame,
                model,
                aligned,
                only_center_face,
                weight,
                save_tmp_steps,
                preserve_alpha,
            )
            for frame in image  # iterates over dim 0, i.e. image[i]
        ]
        if not restored:
            # torch.cat rejects an empty list; nothing to restore.
            return (image,)

        return (torch.cat(restored, dim=0),)

    def get_step_image_path(self, step, idx):
        """Build a path in ComfyUI's temp directory for an intermediate image.

        The prefix is ``{step}_{idx:03}``. The numeric counter comes from
        ``folder_paths.get_save_image_path``.
        """
        (
            full_output_folder,
            filename,
            counter,
            _subfolder,
            _filename_prefix,
        ) = folder_paths.get_save_image_path(
            f"{step}_{idx:03}",
            folder_paths.temp_directory,
        )
        file = f"{filename}_{counter:05}_.png"

        return os.path.join(full_output_folder, file)

    def save_intermediate_images(
        self, cropped_faces, restored_faces, height, width
    ):
        """Write each cropped face, its restored version and a side-by-side
        comparison to the temp directory.

        ``height`` and ``width`` are accepted for compatibility but unused.
        """
        for idx, (cropped_face, restored_face) in enumerate(
            zip(cropped_faces, restored_faces, strict=False)
        ):
            face_id = idx + 1
            file = self.get_step_image_path("cropped_faces", face_id)
            cv2.imwrite(file, cropped_face)

            file = self.get_step_image_path("cropped_faces_restored", face_id)
            cv2.imwrite(file, restored_face)

            file = self.get_step_image_path("cropped_faces_compare", face_id)
            # Horizontal concatenation: original crop on the left.
            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
            cv2.imwrite(file, cmp_img)
|
|
|
|
|
# Node classes exported by this module (presumably picked up by the
# package's node registration; confirm against the loader).
__nodes__ = [
    MTB_RestoreFace,
    MTB_LoadFaceEnhanceModel,
]
|
|