|
import tempfile |
|
from argparse import Namespace |
|
import dlib |
|
import imageio |
|
import numpy as np |
|
import torch |
|
import torchvision.transforms as transforms |
|
from cog import BasePredictor, Path, Input |
|
|
|
from datasets.augmentations import AgeTransformer |
|
from models.psp import pSp |
|
from scripts.align_all_parallel import align_face |
|
from utils.common import tensor2im |
|
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor that renders an aligned face at one or more target ages."""

    def setup(self) -> None:
        """Prepare the preprocessing pipeline, the model options and the model.

        The network is built here, once, instead of on every ``predict`` call,
        so repeated predictions reuse the same loaded model.
        """
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        model_path = "pretrained_models/sam_ffhq_aging.pt"
        # Loaded on CPU only to read the options stored in the checkpoint.
        ckpt = torch.load(model_path, map_location="cpu")

        opts = ckpt["opts"]
        opts["checkpoint_path"] = model_path
        opts["device"] = "cuda" if torch.cuda.is_available() else "cpu"
        self.opts = Namespace(**opts)

        # Build the network once and place it on the same device that was
        # recorded in the options, so model and inputs always agree.
        self.net = pSp(self.opts)
        self.net.eval()
        self.net.to(self.opts.device)

    def predict(
        self,
        image: Path = Input(
            description="facial image",
        ),
        target_age: str = Input(
            description="age of the output image, when choose 'default' "
            "a gif for age from 0, 10, 20,...,to 100 will be displayed",
        ),
    ) -> Path:
        """Run the model on ``image`` for the requested age(s).

        Args:
            image: Path of the input face image; it is aligned with
                ``run_alignment`` before being fed to the model.
            target_age: Either the literal ``"default"`` (produce one frame
                for each of the ages 0, 10, ..., 100) or a single age that is
                forwarded unchanged to ``AgeTransformer``.

        Returns:
            Path of ``output.gif`` (one frame per age) when ``target_age`` is
            ``"default"``, otherwise path of ``output.png`` holding the single
            result. The file is written into a fresh temporary directory.
        """
        device = self.opts.device

        aligned_image = run_alignment(str(image))
        # No explicit resize here: self.transform already resizes to 256x256
        # (the previous stand-alone ``resize`` call discarded its result).
        input_image = self.transform(aligned_image)

        animate = target_age == "default"
        if animate:
            target_ages = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
        else:
            target_ages = [target_age]
        age_transformers = [AgeTransformer(target_age=age) for age in target_ages]

        all_imgs = []
        with torch.no_grad():
            for age_transformer in age_transformers:
                print(f"Running on target age: {age_transformer.target_age}")
                # The age transform is applied on a CPU tensor; the result is
                # then moved to the configured device rather than a
                # hard-coded "cuda".
                aged_input = age_transformer(input_image.cpu()).to(device)
                batch = torch.stack([aged_input])
                result_tensor = run_on_batch(batch, self.net)[0]
                all_imgs.append(tensor2im(result_tensor))

        out_dir = Path(tempfile.mkdtemp())
        if animate:
            out_path = out_dir / "output.gif"
            imageio.mimwrite(str(out_path), all_imgs, duration=0.3)
        else:
            out_path = out_dir / "output.png"
            imageio.imwrite(str(out_path), all_imgs[0])
        return out_path
|
|
|
|
|
def run_alignment(image_path):
    """Align the face found in the image at ``image_path``.

    Loads the dlib shape predictor from
    ``shape_predictor_68_face_landmarks.dat`` (relative path), hands it to
    ``align_face`` together with the image path, prints the ``size`` of the
    aligned image and returns that image.
    """
    landmark_model = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    face = align_face(filepath=image_path, predictor=landmark_model)
    print(f"Aligned image has shape: {face.size}")
    return face
|
|
|
|
|
def run_on_batch(inputs, net):
    """Run ``net`` on a batch of inputs and return its output.

    The inputs are cast to float and moved to the device that holds the
    network's parameters, so the call works both on GPU and on CPU-only
    hosts instead of always requiring "cuda".

    Args:
        inputs: Batch tensor to feed to the network.
        net: Network to call; it is invoked with ``randomize_noise=False``
            and ``resize=False``.

    Returns:
        Whatever ``net`` returns for the batch.
    """
    # Follow the network's own placement; if it has no parameters at all,
    # leave the inputs on whatever device they already live on.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else inputs.device
    return net(inputs.to(device).float(), randomize_noise=False, resize=False)
|
|