File size: 1,969 Bytes
ceb80dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32b433e
ceb80dd
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import gradio as gr
import torch
from PIL import Image
from model import Model as Model
from annotated_directions import annotated_directions
# This app only runs inference: switch autograd off process-wide and keep
# everything on the CPU.
torch.set_grad_enabled(False)
device = torch.device("cpu")

# Generator whose pre-annotated edit directions this demo exposes.
model_name = "stylegan2_ffhq1024"

# Every annotated direction name for that generator (offered in the dropdown).
directions = list(annotated_directions[model_name])


def inference(seed, direction):
    """Apply one pre-annotated edit and visualise its effect.

    Looks up ``direction`` in ``annotated_directions[model_name]``, builds a
    ``Model`` at the annotated layer, loads the saved ``Us``/``Uc`` factors
    from the two entries of ``checkpoints_path``, runs
    ``Model.edit_at_layer`` with the annotated ``(part, appearance, lam)``
    parameters, and stitches the result into a single image.

    Args:
        seed: Forwarded to ``edit_at_layer`` as ``t``; presumably the latent
            sampling seed -- TODO confirm in ``model.py``.
        direction: Key of an entry in ``annotated_directions[model_name]``.

    Returns:
        A ``PIL.Image.Image`` holding three panels concatenated along axis 1:
        the original image, the edited image, and the per-pixel squared
        difference averaged over the last axis (saturated at 255 and repeated
        over 3 channels).

    Raises:
        KeyError: If ``direction`` is not annotated for ``model_name`` or its
            entry lacks a required field.
    """
    spec = annotated_directions[model_name][direction]

    M = Model(model_name, trunc_psi=1.0, device=device, layer=spec['layer'])
    M.ranks = spec['ranks']

    # load the checkpoint. Keep the try narrow: only the config lookup should
    # be re-worded, and chain the original error so the missing key is visible.
    try:
        checkpoints = spec['checkpoints_path']
    except KeyError as err:
        raise KeyError(
            "ERROR: no 'checkpoints_path' specified in ./annotated_directions.py "
            f"for direction {direction!r} of model {model_name!r}"
        ) from err
    M.Us = torch.Tensor(np.load(checkpoints[0])).to(device)
    M.Uc = torch.Tensor(np.load(checkpoints[1])).to(device)

    part, appearance, lam = spec['parameters']

    # First and last return values (latent, part image) are not displayed.
    _, image, image2, _ = M.edit_at_layer([[part]], [appearance], [lam], t=seed, Uc=M.Uc, Us=M.Us, noise=None)

    # Promote to float before differencing: if the images are uint8 (as the
    # Image.fromarray call on their concatenation suggests), ``image - image2``
    # would wrap around modulo 256 and the square would overflow again.
    sq_err = (np.asarray(image, dtype=np.float64) - np.asarray(image2, dtype=np.float64)) ** 2
    mse = np.mean(sq_err, axis=-1)
    # Saturate at 255 instead of relying on an out-of-range float -> uint8 cast.
    dif = np.tile(np.clip(mse, 0, 255)[:, :, None], (1, 1, 3)).astype(np.uint8)

    return Image.fromarray(np.concatenate([image, image2, dif], 1))


# Gradio UI: a seed slider and a direction dropdown feed ``inference``; the
# single output panel shows original | edited | difference side by side.
demo = gr.Interface(
    fn=inference,
    inputs=[
        # Explicit step=1 so every integer seed in [0, 1000] is selectable;
        # some Gradio releases otherwise derive a coarser step from the range.
        gr.Slider(0, 1000, value=64, step=1),
        gr.Dropdown(directions, value='no_eyebrows'),
    ],
    outputs=[gr.Image(type="pil", value='./default.png', label="original | edited | mean-squared difference")],
    title="PandA (ICLR'23) - FFHQ edit zoo",
    description="Provides a quick interface to manipulate pre-annotated directions with pre-trained global parts and appearances factors. Note that we use the free CPU tier, so synthesis takes about 10 seconds.",
    article="Check out the full demo and paper at: <a href='https://github.com/james-oldfield/PandA'>https://github.com/james-oldfield/PandA</a>",
)
demo.launch()