File size: 3,252 Bytes
2302223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import functools
import tempfile
from argparse import Namespace

import dlib
import imageio
import numpy as np
import torch
import torchvision.transforms as transforms
from cog import BasePredictor, Path, Input

from datasets.augmentations import AgeTransformer
from models.psp import pSp
from scripts.align_all_parallel import align_face
from utils.common import tensor2im


class Predictor(BasePredictor):
    """Cog predictor that re-ages a face image with the SAM pSp network."""

    def setup(self):
        """Build the input transform and load the aging network once.

        The checkpoint's stored ``opts`` are reused to construct the network,
        with ``checkpoint_path`` and ``device`` overridden for this machine.
        """
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        model_path = "pretrained_models/sam_ffhq_aging.pt"
        ckpt = torch.load(model_path, map_location="cpu")

        opts = ckpt["opts"]
        opts["checkpoint_path"] = model_path
        opts["device"] = "cuda" if torch.cuda.is_available() else "cpu"

        self.opts = Namespace(**opts)

        # Build the network here rather than inside predict(): constructing
        # pSp(...) presumably loads the weights from opts.checkpoint_path, so
        # doing it per request would repeat that expensive load every time.
        self.net = pSp(self.opts)
        self.net.eval()
        if torch.cuda.is_available():
            self.net.cuda()

    def predict(
            self,
            image: Path = Input(
                description="facial image",
            ),
            target_age: str = Input(
                description="age of the output image, when choose 'default' "
                            "a gif for age from 0, 10, 20,...,to 100 will be displayed",
            ),
    ) -> Path:
        """Age the face in ``image`` and return the path of the rendered output.

        Args:
            image: Path of the input photo; the face is aligned before inference.
            target_age: ``"default"`` renders ages 0, 10, ..., 100 as an
                animated GIF; any other value is passed as the single target
                age to ``AgeTransformer`` and the result is written as a PNG.

        Returns:
            Path of ``output.gif`` or ``output.png`` inside a fresh temp directory.
        """
        # Use the device chosen in setup() instead of hard-coding CUDA, so the
        # predictor also runs on CPU-only hosts.
        device = self.opts.device

        # No explicit resize is needed here: self.transform resizes to 256x256.
        aligned_image = run_alignment(str(image))
        input_image = self.transform(aligned_image)

        if target_age == "default":
            target_ages = list(range(0, 101, 10))
        else:
            target_ages = [target_age]
        age_transformers = [AgeTransformer(target_age=age) for age in target_ages]

        all_imgs = []
        with torch.no_grad():
            for age_transformer in age_transformers:
                print(f"Running on target age: {age_transformer.target_age}")
                # Add a batch dimension and move to the configured device.
                aged_input = age_transformer(input_image).unsqueeze(0).to(device)
                result_tensor = run_on_batch(aged_input, self.net)[0]
                all_imgs.append(tensor2im(result_tensor))

        out_dir = Path(tempfile.mkdtemp())
        if target_age == "default":
            out_path = out_dir / "output.gif"
            imageio.mimwrite(str(out_path), all_imgs, duration=0.3)
        else:
            out_path = out_dir / "output.png"
            imageio.imwrite(str(out_path), all_imgs[0])
        return out_path


@functools.lru_cache(maxsize=1)
def _load_shape_predictor():
    """Load the dlib 68-point face-landmark model once and reuse it."""
    return dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")


def run_alignment(image_path):
    """Detect and align the face found in ``image_path``.

    The landmark model is loaded lazily on first use and cached, instead of
    being re-read from disk on every request.

    Args:
        image_path: Filesystem path of the input image, as a string.

    Returns:
        The aligned image produced by ``align_face``.
    """
    predictor = _load_shape_predictor()
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image


def run_on_batch(inputs, net, device=None):
    """Run ``net`` on a batch of inputs and return its raw output.

    Args:
        inputs: Batched input tensor; it is moved to ``device`` and cast to float.
        net: Callable model accepting ``randomize_noise`` and ``resize`` kwargs.
        device: Target device. Defaults to ``"cuda"`` when CUDA is available,
            otherwise ``"cpu"``, so CPU-only hosts no longer crash.

    Returns:
        Whatever ``net`` returns for the batch.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return net(inputs.to(device).float(), randomize_noise=False, resize=False)