File size: 1,786 Bytes
a23872f
 
 
 
 
 
 
 
 
 
 
 
 
 
e0f92a0
a23872f
 
 
 
eac223c
a23872f
 
 
 
eac223c
 
a23872f
e0f92a0
 
 
 
a23872f
 
e0f92a0
 
a23872f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0f92a0
a23872f
 
e0f92a0
a23872f
e0f92a0
 
82e6d22
 
e0f92a0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import sys

from img_processing import custom_to_pil, preprocess, preprocess_vqgan

sys.path.append("taming-transformers")
import glob

import gradio as gr
import matplotlib.pyplot as plt
import PIL
import taming
import torch

from loaders import load_config, load_default
from utils import get_device


def get_embedding(model, path=None, img=None, device="cpu"):
    """Encode the image at *path* with ``model.encode`` and return the latent.

    Args:
        model: Object exposing ``encode(x)`` returning a 3-tuple whose first
            element is the latent (the other two items are discarded here).
        path: Path (or anything ``PIL.Image.open`` accepts) of the image.
        img: Reserved for passing an already-loaded image; not implemented.
        device: Device the preprocessed image tensor is moved to before
            encoding.

    Returns:
        The latent ``z`` produced by ``model.encode``.

    Raises:
        NotImplementedError: if ``img`` is given.
        ValueError: if no usable ``path`` is given.
    """
    # Compare ``img`` with None explicitly: truth-testing a multi-element
    # tensor raises, and ``assert``-based validation vanishes under ``-O``.
    if img is not None:
        raise NotImplementedError
    if not path:
        raise ValueError("Input either path or tensor")
    # The context manager closes the file handle once preprocessing is done.
    with PIL.Image.open(path) as pil_image:
        x = preprocess(pil_image, target_image_size=256).to(device)
    x_processed = preprocess_vqgan(x)
    z, _, _ = model.encode(x_processed)
    return z


def _load_preview(path):
    """Return the preprocessed image at *path* as an HWC CPU tensor for imshow."""
    # Preview tensors are only displayed, so there is no need to move them to
    # the model device; ``with`` closes the underlying file handle.
    with PIL.Image.open(path) as pil_image:
        batch = preprocess(pil_image, target_image_size=256)
    return batch.cpu().permute(0, 2, 3, 1)[0]


def blend_paths(
    model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"
):
    """Decode a linear interpolation between the latents of two images.

    Args:
        model: Model exposing ``encode``/``decode`` and, when ``quantize`` is
            set, ``quantize`` (whose first returned item is used).
        path1: Path of the first image (the result at ``weight=0``).
        path2: Path of the second image (the result at ``weight=1``).
        quantize: If True, pass the blended latent through
            ``model.quantize`` before decoding.
        weight: Interpolation weight given to ``torch.lerp``.
        show: If True, plot first image / blend / second image side by side
            and call ``plt.show()``.
        device: Device the input images are moved to for encoding.

    Returns:
        ``(custom_to_pil(decoded), z)``: the decoded blend converted by
        ``custom_to_pil`` and the (possibly quantized) blended latent.
    """
    x_latent = get_embedding(model, path=path1, device=device)
    y_latent = get_embedding(model, path=path2, device=device)
    z = torch.lerp(x_latent, y_latent, weight)
    if quantize:
        z = model.quantize(z)[0]
    # Keep only the first item of the decoded batch.
    decoded = model.decode(z)[0]
    if show:
        # Preview images are only loaded when they are actually displayed.
        plt.figure(figsize=(10, 20))
        plt.subplot(1, 3, 1)
        plt.imshow(_load_preview(path1))
        plt.subplot(1, 3, 2)
        plt.imshow(custom_to_pil(decoded))
        plt.subplot(1, 3, 3)
        plt.imshow(_load_preview(path2))
        plt.show()
    return custom_to_pil(decoded), z


if __name__ == "__main__":
    # Demo: blend two sample images halfway between their latents.
    device = get_device()
    model = load_default(device)
    model.to(device)
    # Forward the chosen device explicitly: blend_paths defaults to "cuda",
    # which would not match the model whenever get_device() selects another
    # device (and fails outright on machines without CUDA).
    blend_paths(
        model,
        "./test_pics/face.jpeg",
        "./test_pics/face2.jpeg",
        quantize=False,
        weight=0.5,
        device=device,
    )
    plt.show()