Face-editor / edit.py
Erwann Millon
initial commit
a23872f
raw history blame
No virus
2.48 kB
import os
import sys
from img_processing import custom_to_pil, preprocess, preprocess_vqgan
sys.path.append("taming-transformers")
import glob
import gradio as gr
import matplotlib.pyplot as plt
import PIL
import taming
import torch
from loaders import load_config
from utils import get_device
def get_embedding(model, path=None, img=None, device="cpu"):
assert path is None or img is None, "Input either path or tensor"
if img is not None:
raise NotImplementedError
x = preprocess(PIL.Image.open(path), target_image_size=256).to(device)
x_processed = preprocess_vqgan(x)
x_latent, _, [_, _, indices] = model.encode(x_processed)
return x_latent
def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"):
x = preprocess(PIL.Image.open(path1), target_image_size=256).to(device)
y = preprocess(PIL.Image.open(path2), target_image_size=256).to(device)
x_latent, y_latent = get_embedding(model, path=path1, device=device), get_embedding(model, path=path2, device=device)
z = torch.lerp(x_latent, y_latent, weight)
if quantize:
z = model.quantize(z)[0]
decoded = model.decode(z)[0]
if show:
plt.figure(figsize=(10, 20))
plt.subplot(1, 3, 1)
plt.imshow(x.cpu().permute(0, 2, 3, 1)[0])
plt.subplot(1, 3, 2)
plt.imshow(custom_to_pil(decoded))
plt.subplot(1, 3, 3)
plt.imshow(y.cpu().permute(0, 2, 3, 1)[0])
plt.show()
return custom_to_pil(decoded), z
if __name__ == "__main__":
device = get_device()
# conf_path = "logs/2021-04-23T18-11-19_celebahq_transformer/configs/2021-04-23T18-11-19-project.yaml"
ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
# ckpt_path = "./faceshq/faceshq.pt"
conf_path = "./unwrapped.yaml"
# conf_path = "./faceshq/faceshq.yaml"
config = load_config(conf_path, display=False)
model = taming.models.vqgan.VQModel(**config.model.params)
sd = torch.load("./vqgan_only.pt", map_location="mps")
model.load_state_dict(sd, strict=True)
model.to(device)
blend_paths(model, "./test_data/face.jpeg", "./test_data/face2.jpeg", quantize=False, weight=.5)
plt.show()
demo = gr.Interface(
get_image,
inputs=gr.inputs.Image(label="UploadZz a black and white face", type="filepath"),
outputs="image",
title="Upload a black and white face and get a colorized image!",
)