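"""Blend two faces by interpolating their VQGAN latent codes.

Encodes two face images with a pretrained VQGAN (CelebA-HQ checkpoint),
linearly interpolates the latents, decodes the blend, and wraps the
encode/decode round trip in a small Gradio demo.
"""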
import sys

# Make the vendored taming-transformers checkout importable before anything
# that depends on it.
sys.path.append("taming-transformers")

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from taming.models.vqgan import VQModel

from img_processing import custom_to_pil, preprocess, preprocess_vqgan
from loaders import load_config
from utils import get_device


def get_embedding(model, path=None, img=None, device="cpu"):
    """Encode a single image into the VQGAN latent space."""
    assert (path is None) != (img is None), "Pass either a path or an image tensor, not both"
    if img is not None:
        raise NotImplementedError("Tensor input is not supported yet")
    x = preprocess(Image.open(path), target_image_size=256).to(device)
    x_processed = preprocess_vqgan(x)
    x_latent, _, _ = model.encode(x_processed)
    return x_latent

    
def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, device="cpu"):
    """Blend two images by linearly interpolating their VQGAN latents."""
    x = preprocess(Image.open(path1), target_image_size=256).to(device)
    y = preprocess(Image.open(path2), target_image_size=256).to(device)
    x_latent = get_embedding(model, path=path1, device=device)
    y_latent = get_embedding(model, path=path2, device=device)
    # weight=0 reproduces path1, weight=1 reproduces path2.
    z = torch.lerp(x_latent, y_latent, weight)
    if quantize:
        # Optionally snap the blended latent back onto the nearest codebook entries.
        z = model.quantize(z)[0]
    decoded = model.decode(z)[0]
    if show:
        # Source 1 | blend | source 2, side by side.
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.imshow(x.cpu().permute(0, 2, 3, 1)[0])
        plt.subplot(1, 3, 2)
        plt.imshow(custom_to_pil(decoded))
        plt.subplot(1, 3, 3)
        plt.imshow(y.cpu().permute(0, 2, 3, 1)[0])
        plt.show()
    return custom_to_pil(decoded), z

if __name__ == "__main__":
    device = get_device()
    # Alternative configs/checkpoints, kept for reference (ckpt_path is unused:
    # the weights below are loaded from vqgan_only.pt):
    # conf_path = "logs/2021-04-23T18-11-19_celebahq_transformer/configs/2021-04-23T18-11-19-project.yaml"
    # ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
    # ckpt_path = "./faceshq/faceshq.pt"
    # conf_path = "./faceshq/faceshq.yaml"
    conf_path = "./unwrapped.yaml"
    config = load_config(conf_path, display=False)
    model = VQModel(**config.model.params)
    # Load weights to CPU first, then move the whole model to the chosen device.
    sd = torch.load("./vqgan_only.pt", map_location="cpu")
    model.load_state_dict(sd, strict=True)
    model.to(device)
    blend_paths(model, "./test_data/face.jpeg", "./test_data/face2.jpeg", quantize=False, weight=0.5, device=device)

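    # The original file referenced an undefined `get_image` for the Gradio demo.
    # This helper is a minimal sketch of what it plausibly did, assuming the demo
    # colorizes by running the upload through the VQGAN encode/decode round trip
    # (the model was trained on color faces, so reconstructions come out in color).
    def get_image(path):
        z = get_embedding(model, path=path, device=device)
        decoded = model.decode(z)[0]
        return custom_to_pil(decoded)
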
    demo = gr.Interface(
        get_image,
        inputs=gr.Image(label="Upload a black and white face", type="filepath"),
        outputs="image",
        title="Upload a black and white face and get a colorized image!",
    )
    demo.launch()