import os
import pprint
import shutil
import sys
import tempfile
import time
from argparse import Namespace

import dlib
import numpy as np
import torch
import torchvision.transforms as transforms
from cog import BasePredictor, Input, Path
from PIL import Image

sys.path.append(".")
sys.path.append("..")

from datasets import augmentations
from utils.common import tensor2im, log_input_image
from models.psp import pSp
from scripts.align_all_parallel import align_face

# Checkpoint file for each value of the `model` input of Predictor.predict.
MODEL_PATHS = {
    "ffhq_frontalize": "pretrained_models/psp_ffhq_frontalization.pt",
    "celebs_sketch_to_face": "pretrained_models/psp_celebs_sketch_to_face.pt",
    "celebs_super_resolution": "pretrained_models/psp_celebs_super_resolution.pt",
    "toonify": "pretrained_models/psp_ffhq_toonify.pt",
}

# Models whose input is used as-is (no face alignment) and which get a random
# latent injected into SKETCH_LATENT_MASK layers. "celebs_seg_to_face" has no
# checkpoint in MODEL_PATHS; it is listed only to mirror the original checks.
CONDITIONAL_INPUT_MODELS = ("celebs_sketch_to_face", "celebs_seg_to_face")

# Latent layer indices passed as `latent_mask` for CONDITIONAL_INPUT_MODELS.
SKETCH_LATENT_MASK = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17)


class Predictor(BasePredictor):
    def setup(self):
        """Load the dlib landmark predictor, per-model options and input transforms.

        Each checkpoint is read here only to extract its stored ``opts`` dict;
        the network itself is built from those options in ``predict``.
        """
        self.predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")

        self.opts = {}
        for key, checkpoint_path in MODEL_PATHS.items():
            ckpt = torch.load(checkpoint_path, map_location="cpu")
            opts = ckpt["opts"]
            # Drop the full checkpoint immediately so the four state dicts are
            # never resident in memory at the same time.
            del ckpt
            opts["checkpoint_path"] = checkpoint_path
            # Some checkpoints may lack these keys; fill in the defaults used here.
            opts.setdefault("learn_in_w", False)
            opts.setdefault("output_size", 1024)
            self.opts[key] = opts

        # Normalize holds no per-call state, so one instance can be shared.
        normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        self.transforms = {
            "ffhq_frontalize": transforms.Compose(
                [transforms.Resize((256, 256)), transforms.ToTensor(), normalize]
            ),
            "toonify": transforms.Compose(
                [transforms.Resize((256, 256)), transforms.ToTensor(), normalize]
            ),
            # Sketch input is left un-normalized.
            "celebs_sketch_to_face": transforms.Compose(
                [transforms.Resize((256, 256)), transforms.ToTensor()]
            ),
            # Degrade the input by a factor of 16 and scale it back to 256x256
            # before feeding it to the super-resolution model.
            "celebs_super_resolution": transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    augmentations.BilinearResize(factors=[16]),
                    transforms.Resize((256, 256)),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        }

    def predict(
        self,
        image: Path = Input(description="input image"),
        model: str = Input(
            choices=[
                "celebs_sketch_to_face",
                "ffhq_frontalize",
                "celebs_super_resolution",
                "toonify",
            ],
            description="choose model type",
        ),
    ) -> Path:
        """Run the selected pSp model on *image* and write the result as a PNG.

        Args:
            image: Path to the input image.
            model: Key selecting the checkpoint, options and transform set in ``setup``.

        Returns:
            Path to ``out.png`` in a fresh temporary directory. The image is
            ``output_size`` x ``output_size``. For ``celebs_super_resolution``,
            the model input and the result are placed side by side.
        """
        model_opts = self.opts[model]
        output_size = model_opts["output_size"]
        opts = Namespace(**model_opts)
        pprint.pprint(opts)

        # NOTE(review): pSp(opts) presumably reloads weights from
        # opts.checkpoint_path on every request; caching the networks would
        # avoid that at the cost of keeping them all resident. TODO confirm.
        net = pSp(opts)
        net.eval()
        net.cuda()
        print("Model successfully loaded!")

        if model in CONDITIONAL_INPUT_MODELS:
            # Sketch/segmentation inputs are not faces, so skip alignment.
            input_image = self._load_image(str(image), opts)
        else:
            input_image = self.run_alignment(str(image))

        transformed_image = self.transforms[model](input_image)

        latent_mask = (
            list(SKETCH_LATENT_MASK) if model in CONDITIONAL_INPUT_MODELS else None
        )

        with torch.no_grad():
            result_image = run_on_batch(
                transformed_image.unsqueeze(0), net, latent_mask
            )[0]

        size = (output_size, output_size)
        output_image = tensor2im(result_image).resize(size)
        if model == "celebs_super_resolution":
            # Concatenate along the width axis: [model input | result].
            input_vis_image = log_input_image(transformed_image, opts).resize(size)
            res = np.concatenate(
                [np.array(input_vis_image), np.array(output_image)], axis=1
            )
        else:
            res = np.array(output_image)

        out_path = Path(tempfile.mkdtemp()) / "out.png"
        Image.fromarray(res).save(str(out_path))
        return out_path

    @staticmethod
    def _load_image(image_path, opts):
        """Open *image_path* as "RGB" when ``opts.label_nc == 0``, else as "L"."""
        with Image.open(image_path) as img:
            # convert() returns a new in-memory image, so it remains valid
            # after the underlying file is closed.
            return img.convert("RGB" if opts.label_nc == 0 else "L")

    def run_alignment(self, image_path):
        """Return the face-aligned image produced by ``align_face`` for *image_path*."""
        aligned_image = align_face(filepath=image_path, predictor=self.predictor)
        print("Aligned image has shape: {}".format(aligned_image.size))
        return aligned_image


def run_on_batch(inputs, net, latent_mask=None):
    """Run *net* on a batch of images on the CUDA device.

    Args:
        inputs: Batch tensor; its first dimension is iterated as individual images.
        net: The pSp network.
        latent_mask: If None, the batch is run in one forward pass with
            ``randomize_noise=False``. Otherwise each image is run separately,
            and a latent derived from a fresh random 512-d code is injected
            into the listed layers.

    Returns:
        The network output for the batch (outputs concatenated along dim 0 in
        the per-image case).
    """
    if latent_mask is None:
        return net(inputs.to("cuda").float(), randomize_noise=False)

    result_batch = []
    for input_image in inputs:
        # Sample a random code and map it to the latent to inject.
        vec_to_inject = np.random.randn(1, 512).astype("float32")
        _, latent_to_inject = net(
            torch.from_numpy(vec_to_inject).to("cuda"),
            input_code=True,
            return_latents=True,
        )
        # Generate the output with the injected style latent.
        res = net(
            input_image.unsqueeze(0).to("cuda").float(),
            latent_mask=latent_mask,
            inject_latent=latent_to_inject,
            resize=False,
        )
        result_batch.append(res)
    return torch.cat(result_batch, dim=0)