# Provenance (page header captured together with this file): Hugging Face
# Spaces, owner "ginipick", "Running on Zero"; last commit "Squashing commit"
# (4450790, verified) by multimodalart; raw file size 7.26 kB.
# Optional face swap nodes (insightface face analysis + face swap)
# region imports
import contextlib
import sys
from pathlib import Path

import comfy.model_management as model_management
import cv2
import insightface
import numpy as np
import onnxruntime
import torch
from insightface.model_zoo.inswapper import INSwapper
from PIL import Image

from ..errors import ModelNotFound
from ..log import NullWriter, mklog
from ..utils import download_antelopev2, get_model_path, pil2tensor, tensor2pil
# endregion
# Module-level logger named after this module (created by ``mklog`` from ``..log``).
log = mklog(__name__)
class MTB_LoadFaceAnalysisModel:
    """Node that builds an insightface ``FaceAnalysis`` instance.

    The chosen model pack is resolved under the ``insightface`` model
    directory. For ``antelopev2``, ``download_antelopev2()`` is called first.
    """

    # Not referenced anywhere in this file; kept so the class interface is unchanged.
    models = []

    @classmethod
    def INPUT_TYPES(cls):
        model_choices = ["antelopev2", "buffalo_l", "buffalo_m", "buffalo_sc"]
        faceswap_model_spec = (model_choices, {"default": "buffalo_l"})
        return {"required": {"faceswap_model": faceswap_model_spec}}

    RETURN_TYPES = ("FACE_ANALYSIS_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def load_model(self, faceswap_model: str):
        """Return a one-tuple holding a ``FaceAnalysis`` for *faceswap_model*.

        Args:
            faceswap_model: Name of the insightface model pack to load.
        """
        needs_download = faceswap_model == "antelopev2"
        if needs_download:
            download_antelopev2()
        models_root = get_model_path("insightface").as_posix()
        analyser = insightface.app.FaceAnalysis(name=faceswap_model, root=models_root)
        return (analyser,)
class MTB_LoadFaceSwapModel:
    """Node that loads an insightface ``INSwapper`` face-swap model.

    Candidate files are the ``.onnx`` / ``.pth`` entries of the
    ``insightface`` model directory.
    """

    @staticmethod
    def get_models() -> list[Path]:
        """Return the candidate swap-model files, sorted by file name.

        ``Path.iterdir`` yields entries in arbitrary order. Sorting keeps the
        node's dropdown stable across runs and machines.

        Returns:
            The matching paths, or an empty list when the insightface model
            directory does not exist.
        """
        models_path = get_model_path("insightface")
        if not models_path.exists():
            return []
        return sorted(
            (x for x in models_path.iterdir() if x.suffix in (".onnx", ".pth")),
            key=lambda p: p.name,
        )

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "faceswap_model": (
                    [x.name for x in cls.get_models()],
                    {"default": "None"},
                ),
            },
        }

    RETURN_TYPES = ("FACESWAP_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def load_model(self, faceswap_model: str):
        """Load *faceswap_model* from the insightface model directory.

        Args:
            faceswap_model: File name of the model inside that directory.

        Returns:
            A one-tuple holding an ``INSwapper`` backed by an
            ``onnxruntime.InferenceSession`` that uses every available
            execution provider.

        Raises:
            ModelNotFound: If the resolved path is falsy or does not exist.
        """
        model_path = get_model_path("insightface", faceswap_model)
        if not model_path or not model_path.exists():
            raise ModelNotFound(f"{faceswap_model} ({model_path})")
        log.info(f"Loading model {model_path}")
        # Some older onnxruntime releases only accept str/bytes for
        # ``path_or_bytes`` and raise TypeError on ``pathlib.Path``, so pass
        # a plain string.
        model_file = str(model_path)
        session = onnxruntime.InferenceSession(
            path_or_bytes=model_file,
            providers=onnxruntime.get_available_providers(),
        )
        return (INSwapper(model_file, session),)
# region roop node
class MTB_FaceSwap:
    """Face swap using deepinsight/insightface models.

    Takes a loaded face-analysis model and a loaded face-swap model and
    applies ``swap_face`` to every image of the input batch.
    """

    # Neither attribute is read by the methods below; kept for interface compatibility.
    model = None
    model_path = None

    def __init__(self) -> None:
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "reference": ("IMAGE",),
                "faces_index": ("STRING", {"default": "0"}),
                "faceanalysis_model": (
                    "FACE_ANALYSIS_MODEL",
                    {"default": "None"},
                ),
                "faceswap_model": ("FACESWAP_MODEL", {"default": "None"}),
            },
            "optional": {
                "preserve_alpha": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "swap"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def swap(
        self,
        image: torch.Tensor,
        reference: torch.Tensor,
        faces_index: str,
        faceanalysis_model,
        faceswap_model,
        preserve_alpha=False,
    ):
        """Swap the reference face into every image of the batch.

        Args:
            image: Target image batch. Frames are processed one at a time.
            reference: Image providing the source face; must have batch size 1.
            faces_index: Comma-separated target face indices, e.g. ``"0,2"``.
                Whitespace around entries is ignored. Entries that are not
                decimal integers are skipped.
            faceanalysis_model: The ``FACE_ANALYSIS_MODEL`` input.
            faceswap_model: The ``FACESWAP_MODEL`` input.
            preserve_alpha: When true and a target frame is RGBA, its alpha
                channel is re-applied to the swapped result. Defaults to
                ``False`` for direct calls; the node UI default is ``True``.

        Returns:
            A one-tuple containing the swapped image batch.

        Raises:
            ValueError: If ``reference`` does not have batch size 1.
        """
        batch_count = image.size(0)
        log.info(f"Running insightface swap (batch size: {batch_count})")
        if reference.size(0) != 1:
            raise ValueError("Reference image must have batch size 1")

        # Loop-invariant work: decode the reference and parse indices once.
        ref = tensor2pil(reference)[0]
        # Strip each entry so "0, 1" works. Use isdecimal() rather than
        # isnumeric(), because e.g. "½".isnumeric() is True but int("½") raises.
        face_ids = {
            int(entry)
            for entry in (part.strip() for part in faces_index.split(","))
            if entry.isdecimal()
        }

        def do_swap(frame):
            model_management.throw_exception_if_processing_interrupted()
            img = tensor2pil(frame)[0]
            alpha_channel = None
            if preserve_alpha and img.mode == "RGBA":
                alpha_channel = img.getchannel("A")
                img = img.convert("RGB")
            # redirect_stdout restores the previously active stdout (which is
            # not necessarily sys.__stdout__), even if swap_face raises.
            with contextlib.redirect_stdout(NullWriter()):
                swapped = swap_face(
                    faceanalysis_model, ref, img, faceswap_model, face_ids
                )
            if alpha_channel is not None:
                swapped.putalpha(alpha_channel)
            return pil2tensor(swapped)

        if batch_count == 1:
            image = do_swap(image)
        else:
            image = torch.cat(
                [do_swap(image[i]) for i in range(batch_count)], dim=0
            )
        return (image,)
# endregion
# region face swap utils
def get_face_single(
    face_analyser, img_data: np.ndarray, face_index=0, det_size=(640, 640)
):
    """Return the ``face_index``-th detected face in *img_data*, or ``None``.

    Detected faces are sorted ascending by ``bbox[0]`` (presumably the left
    x coordinate, i.e. left-to-right order) before indexing.

    If nothing is detected and both ``det_size`` dimensions exceed 320, the
    detection is retried recursively with ``det_size`` halved.

    Args:
        face_analyser: Object exposing ``prepare(ctx_id=..., det_size=...)``
            and ``get(img)``, e.g. ``insightface.app.FaceAnalysis``.
        img_data: Image array handed to ``face_analyser.get``. Callers in this
            module pass BGR arrays.
        face_index: Position in the sorted face list.
        det_size: Detector input size passed to ``prepare``.

    Returns:
        The selected face object, or ``None`` when there is no face at
        ``face_index``.
    """
    face_analyser.prepare(ctx_id=0, det_size=det_size)
    faces = face_analyser.get(img_data)
    if len(faces) == 0 and det_size[0] > 320 and det_size[1] > 320:
        log.debug("No face detected, trying again with smaller image")
        det_size_half = (det_size[0] // 2, det_size[1] // 2)
        return get_face_single(
            face_analyser,
            img_data,
            face_index=face_index,
            det_size=det_size_half,
        )
    try:
        return sorted(faces, key=lambda face: face.bbox[0])[face_index]
    except IndexError:
        return None
def swap_face(
    face_analyser,
    source_img: Image.Image | list[Image.Image],
    target_img: Image.Image | list[Image.Image],
    face_swapper_model,
    faces_index: set[int] | None = None,
) -> Image.Image:
    """Swap the first face of *source_img* onto selected faces of *target_img*.

    Args:
        face_analyser: Analyser passed to ``get_face_single`` for detection.
        source_img: Image providing the face (face index 0).
        target_img: Image whose faces are replaced.
        face_swapper_model: Object exposing
            ``get(img, target_face, source_face)``, e.g. an ``INSwapper``.
            When ``None``, the target is returned unchanged.
        faces_index: Indices of the target faces to replace. Defaults to ``{0}``.

    Returns:
        A new PIL image with the swaps applied. ``target_img`` itself is
        returned when there is no swapper model or no source face is found.

    NOTE(review): The ``list[Image.Image]`` variants in the annotations do
    not appear to be handled. ``np.array``/``cv2.cvtColor`` below are used as
    for a single image.
    """
    if faces_index is None:
        faces_index = {0}
    log.debug(f"Swapping faces: {faces_index}")

    if face_swapper_model is None:
        log.error("No face swap model provided")
        return target_img

    cv_source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
    cv_target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)

    source_face = get_face_single(face_analyser, cv_source_img, face_index=0)
    if source_face is None:
        log.warning("No source face found")
        return target_img

    result = cv_target_img
    for face_num in faces_index:
        # Detection runs on the original target, not on ``result``. Each
        # index therefore refers to a face of the unmodified image.
        target_face = get_face_single(
            face_analyser, cv_target_img, face_index=face_num
        )
        if target_face is None:
            log.warning(f"No target face found for {face_num}")
            continue
        # Silence the swapper's prints. The previous stdout is restored even
        # if the swap raises.
        with contextlib.redirect_stdout(NullWriter()):
            result = face_swapper_model.get(result, target_face, source_face)

    return Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
# endregion face swap utils
# Node classes defined in this module; presumably collected by the
# package's node registration (verify against the loader).
__nodes__ = [
    MTB_FaceSwap,
    MTB_LoadFaceSwapModel,
    MTB_LoadFaceAnalysisModel,
]