sam-model / predict.py
Denis
lfs
2302223
raw history blame
No virus
3.25 kB
import functools
import tempfile
from argparse import Namespace

import dlib
import imageio
import numpy as np
import torch
import torchvision.transforms as transforms
from cog import BasePredictor, Path, Input

from datasets.augmentations import AgeTransformer
from models.psp import pSp
from scripts.align_all_parallel import align_face
from utils.common import tensor2im
class Predictor(BasePredictor):
    """Cog predictor that age-transforms a face image with the SAM (pSp) model."""

    def setup(self):
        """Build the input transform and load the model once per worker.

        The pSp network is constructed here rather than in ``predict`` so the
        checkpoint is read from disk a single time instead of on every request.
        """
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        model_path = "pretrained_models/sam_ffhq_aging.pt"
        # Only the stored training options are read from the checkpoint here;
        # presumably pSp loads the weights itself from opts["checkpoint_path"].
        ckpt = torch.load(model_path, map_location="cpu")
        opts = ckpt["opts"]
        # Release the rest of the checkpoint before pSp reads the file again,
        # to keep peak memory down.
        del ckpt
        opts["checkpoint_path"] = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        opts["device"] = self.device
        self.opts = Namespace(**opts)

        self.net = pSp(self.opts)
        self.net.eval()
        self.net.to(self.device)

    def predict(
        self,
        image: Path = Input(
            description="facial image",
        ),
        target_age: str = Input(
            description="age of the output image, when choose 'default' "
            "a gif for age from 0, 10, 20,...,to 100 will be displayed",
        ),
    ) -> Path:
        """Age-transform ``image`` and return the path of the rendered result.

        Args:
            image: Path to the input photo; it is face-aligned before inference.
            target_age: ``"default"`` to render ages 0, 10, ..., 100 as an
                animated GIF; any other value is forwarded unchanged to
                ``AgeTransformer`` and rendered as a single PNG.

        Returns:
            Path to ``output.gif`` (default mode) or ``output.png`` inside a
            freshly created temporary directory.
        """
        aligned_image = run_alignment(str(image))
        # self.transform already resizes to 256x256, so no separate resize here.
        input_image = self.transform(aligned_image)

        if target_age == "default":
            target_ages = list(range(0, 101, 10))
        else:
            target_ages = [target_age]
        age_transformers = [AgeTransformer(target_age=age) for age in target_ages]

        all_imgs = []
        with torch.no_grad():
            for age_transformer in age_transformers:
                print(f"Running on target age: {age_transformer.target_age}")
                # Apply the age transform on the CPU tensor, add a batch
                # dimension, then move to the model's device (CPU or CUDA).
                input_image_age = age_transformer(input_image).unsqueeze(0).to(self.device)
                result_tensor = run_on_batch(input_image_age, self.net)[0]
                all_imgs.append(tensor2im(result_tensor))

        out_dir = Path(tempfile.mkdtemp())
        if target_age == "default":
            out_path = out_dir / "output.gif"
            # NOTE(review): the unit of ``duration`` (seconds vs milliseconds)
            # differs between imageio versions/plugins — confirm against the
            # pinned imageio version.
            imageio.mimwrite(str(out_path), all_imgs, duration=0.3)
        else:
            out_path = out_dir / "output.png"
            imageio.imwrite(str(out_path), all_imgs[0])
        return out_path
@functools.lru_cache(maxsize=None)
def _load_shape_predictor(model_path):
    """Load and memoize the dlib landmark predictor stored at ``model_path``."""
    # Only one model path is used in practice, so the cache stays tiny; this
    # avoids re-reading the landmark model file from disk on every request.
    return dlib.shape_predictor(model_path)


def run_alignment(image_path):
    """Face-align the image at ``image_path`` using dlib's 68-point landmarks.

    Args:
        image_path: Filesystem path of the input image, as a string.

    Returns:
        The aligned image produced by ``align_face``; its ``.size`` is logged.
    """
    predictor = _load_shape_predictor("shape_predictor_68_face_landmarks.dat")
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image
def _infer_device(net):
    """Return the device of ``net``'s first parameter, else CUDA/CPU by availability."""
    parameters = getattr(net, "parameters", None)
    if callable(parameters):
        for param in parameters():
            return param.device
    return "cuda" if torch.cuda.is_available() else "cpu"


def run_on_batch(inputs, net, device=None):
    """Run ``net`` on a batch with noise randomization and resizing disabled.

    Args:
        inputs: Batched input tensor; it is moved to ``device`` and cast to
            float32 before the call.
        net: Model invoked as ``net(x, randomize_noise=False, resize=False)``.
        device: Target device. When omitted, the device of ``net``'s first
            parameter is used (previously this was hard-coded to ``"cuda"``,
            which failed on CPU-only hosts).

    Returns:
        Whatever ``net`` returns for the batch.
    """
    if device is None:
        device = _infer_device(net)
    result_batch = net(inputs.to(device).float(), randomize_noise=False, resize=False)
    return result_batch