from __future__ import annotations

import argparse
import os
import pathlib
import subprocess
import sys
from typing import Callable

import dlib
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torchvision.transforms as T

# When running on Hugging Face Spaces without a GPU, patch the vendored
# submodules so they can run on CPU.
if os.getenv("SYSTEM") == "spaces" and not torch.cuda.is_available():
    with open("patch.e4e") as f:
        subprocess.run("patch -p1".split(), cwd="encoder4editing", stdin=f)
    with open("patch.hairclip") as f:
        subprocess.run("patch -p1".split(), cwd="HairCLIP", stdin=f)

# Make the vendored encoder4editing and HairCLIP repositories importable.
app_dir = pathlib.Path(__file__).parent

e4e_dir = app_dir / "encoder4editing"
sys.path.insert(0, e4e_dir.as_posix())

from models.psp import pSp
from utils.alignment import align_face

hairclip_dir = app_dir / "HairCLIP"
mapper_dir = hairclip_dir / "mapper"
sys.path.insert(0, hairclip_dir.as_posix())
sys.path.insert(0, mapper_dir.as_posix())

from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
from mapper.hairclip_mapper import HairCLIPMapper


class Model:
    """Face alignment, e4e GAN inversion, and text-driven hair editing with HairCLIP."""

    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.landmark_model = self._create_dlib_landmark_model()
        self.e4e = self._load_e4e()
        self.hairclip = self._load_hairclip()
        self.transform = self._create_transform()

    @staticmethod
    def _create_dlib_landmark_model():
        path = huggingface_hub.hf_hub_download(
            "public-data/dlib_face_landmark_model", "shape_predictor_68_face_landmarks.dat"
        )
        return dlib.shape_predictor(path)

    def _load_e4e(self) -> nn.Module:
        ckpt_path = huggingface_hub.hf_hub_download("public-data/e4e", "e4e_ffhq_encode.pt")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["device"] = self.device.type
        opts["checkpoint_path"] = ckpt_path
        opts = argparse.Namespace(**opts)
        model = pSp(opts)
        model.to(self.device)
        model.eval()
        return model

    def _load_hairclip(self) -> nn.Module:
        ckpt_path = huggingface_hub.hf_hub_download("public-data/HairCLIP", "hairclip.pt")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["device"] = self.device.type
        opts["checkpoint_path"] = ckpt_path
        # Default mapper options; generate() overrides editing_type and
        # color_description for each request.
        opts["editing_type"] = "both"
        opts["input_type"] = "text"
        opts["hairstyle_description"] = "HairCLIP/mapper/hairstyle_list.txt"
        opts["color_description"] = "red"
        opts = argparse.Namespace(**opts)
        model = HairCLIPMapper(opts)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _create_transform() -> Callable:
        transform = T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(256),
                T.ToTensor(),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        return transform

    def detect_and_align_face(self, image: str) -> PIL.Image.Image:
        return align_face(filepath=image, predictor=self.landmark_model)

    @staticmethod
    def denormalize(tensor: torch.Tensor) -> torch.Tensor:
        # Map model output from [-1, 1] to [0, 255] uint8.
        return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)

    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
        # Convert a CHW tensor to an HWC uint8 array for display.
        tensor = self.denormalize(tensor)
        return tensor.cpu().numpy().transpose(1, 2, 0)

    @torch.inference_mode()
    def reconstruct_face(self, image: PIL.Image.Image) -> tuple[np.ndarray, torch.Tensor]:
        # Encode the aligned face with e4e and return both the reconstruction
        # and its latent code.
        input_data = self.transform(image).unsqueeze(0).to(self.device)
        reconstructed_images, latents = self.e4e(input_data, randomize_noise=False, return_latents=True)
        reconstructed = torch.clamp(reconstructed_images[0].detach(), -1, 1)
        reconstructed = self.postprocess(reconstructed)
        return reconstructed, latents[0]

    @torch.inference_mode()
    def generate(
        self, editing_type: str, hairstyle_index: int, color_description: str, latent: torch.Tensor
    ) -> np.ndarray:
        opts = self.hairclip.opts
        opts.editing_type = editing_type
        opts.color_description = color_description
        if editing_type == "color":
            hairstyle_index = 0

        device = torch.device(opts.device)

        # Build the text conditioning tensors for the requested hairstyle and color.
        dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(), opts=opts)
        w, hairstyle_text_inputs_list, color_text_inputs_list = dataset[0][:3]

        w = w.unsqueeze(0).to(device)
        hairstyle_text_inputs = hairstyle_text_inputs_list[hairstyle_index].unsqueeze(0).to(device)
        color_text_inputs = color_text_inputs_list[0].unsqueeze(0).to(device)
        hairstyle_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)
        color_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)

        # Shift the latent with the HairCLIP mapper, then decode it back to an image.
        w_hat = w + 0.1 * self.hairclip.mapper(
            w,
            hairstyle_text_inputs,
            color_text_inputs,
            hairstyle_tensor_hairmasked,
            color_tensor_hairmasked,
        )
        x_hat, _ = self.hairclip.decoder(
            [w_hat],
            input_is_latent=True,
            return_latents=True,
            randomize_noise=False,
            truncation=1,
        )
        res = torch.clamp(x_hat[0].detach(), -1, 1)
        res = self.postprocess(res)
        return res
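

# --- Usage sketch (illustrative, not part of the original app) ---------------
# A minimal example of driving the pipeline above: align a face photo, invert
# it with e4e, then edit the latent with HairCLIP. The input/output paths are
# placeholders, and the accompanying demo app may wire these calls differently.
if __name__ == "__main__":
    model = Model()
    aligned = model.detect_and_align_face("input.jpg")  # hypothetical input path
    reconstruction, latent = model.reconstruct_face(aligned)
    edited = model.generate(
        editing_type="both",  # "hairstyle", "color", or "both"
        hairstyle_index=0,  # row index into HairCLIP/mapper/hairstyle_list.txt
        color_description="red",  # free-form color prompt
        latent=latent,
    )
    PIL.Image.fromarray(edited).save("edited.png")  # hypothetical output path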