# Hugging Face Space source file (the Space reported "Build error" when captured).
# File size: 1,969 bytes. Revisions: ceb80dd, 32b433e.
import numpy as np
import gradio as gr
import torch
from PIL import Image
from model import Model as Model
from annotated_directions import annotated_directions
# Gradients are disabled globally for the whole module.
torch.set_grad_enabled(False)
# All tensors created by this module are placed on the CPU.
device = torch.device("cpu")

# Key into ``annotated_directions`` selecting which model's edits are exposed.
model_name = "stylegan2_ffhq1024"
# Names of the edit directions annotated for that model (used by the dropdown).
directions = [name for name in annotated_directions[model_name].keys()]
def inference(seed, direction):
    """Generate an image, apply one annotated edit, and visualise the change.

    Args:
        seed: Forwarded to ``Model.edit_at_layer`` as ``t``.
        direction: Key into ``annotated_directions[model_name]`` naming the
            edit to apply.

    Returns:
        A ``PIL.Image`` made of three panels concatenated along axis 1:
        the original image, the edited image, and a grey-scale map of the
        per-pixel mean-squared difference between them.

    Raises:
        KeyError: If ``direction`` is unknown, or if its entry has no
            usable ``'checkpoints_path'``.
    """
    # Look the annotation up once instead of re-indexing for every field.
    spec = annotated_directions[model_name][direction]

    M = Model(model_name, trunc_psi=1.0, device=device, layer=spec['layer'])
    M.ranks = spec['ranks']

    # Only the lookup is guarded, so that errors raised while reading the
    # checkpoint files are not relabelled as a missing annotation.
    try:
        checkpoints = spec['checkpoints_path']
        us_path, uc_path = checkpoints[0], checkpoints[1]
    except KeyError as err:
        raise KeyError('ERROR: No directions specified in ./annotated_directions.py for this model') from err

    # load the checkpoint
    M.Us = torch.Tensor(np.load(us_path)).to(device)
    M.Uc = torch.Tensor(np.load(uc_path)).to(device)

    part, appearance, lam = spec['parameters']
    Z, image, image2, part_img = M.edit_at_layer(
        [[part]], [appearance], [lam], t=seed, Uc=M.Uc, Us=M.Us, noise=None
    )

    # NOTE(review): image/image2 are presumably uint8 arrays (they are passed
    # to Image.fromarray alongside a uint8 map) -- confirm. Subtracting and
    # squaring in uint8 wraps modulo 256, so the arithmetic is done in float32.
    delta = np.asarray(image, dtype=np.float32) - np.asarray(image2, dtype=np.float32)
    mse = np.mean(delta ** 2, axis=-1)
    # Saturate at 255 rather than letting the uint8 cast wrap large values.
    mse = np.clip(mse, 0, 255)
    # Repeat the single-channel map three times so it can sit next to the
    # colour images in the concatenated output.
    dif = np.tile(mse[:, :, None], [1, 1, 3]).astype(np.uint8)

    return Image.fromarray(np.concatenate([image, image2, dif], 1))
# Gradio front end: a slider and a direction dropdown feed `inference`, whose
# returned image is shown in a single output panel.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Slider(0, 1000, value=64),
        # NOTE(review): assumes 'no_eyebrows' is one of `directions` -- confirm.
        gr.Dropdown(directions, value='no_eyebrows'),
    ],
    outputs=[
        gr.Image(
            type="pil",
            value='./default.png',
            label="original | edited | mean-squared difference",
        )
    ],
    title="PandA (ICLR'23) - FFHQ edit zoo",
    description="Provides a quick interface to manipulate pre-annotated directions with pre-trained global parts and appearances factors. Note that we use the free CPU tier, so synthesis takes about 10 seconds.",
    article="Check out the full demo and paper at: <a href='https://github.com/james-oldfield/PandA'>https://github.com/james-oldfield/PandA</a>",
)
# Start serving the interface as soon as the module is executed.
demo.launch()