Spaces:
Build error
Build error
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from model import Model as Model | |
| from annotated_directions import annotated_directions | |
# Inference only: never record autograd graphs, and run on the CPU
# (the Space description notes it uses the free CPU tier).
torch.set_grad_enabled(False)
device = torch.device('cpu')

# Generator whose pre-annotated edit directions this demo exposes.
model_name = "stylegan2_ffhq1024"
# Dropdown choices: every direction name annotated for this model,
# in the dictionary's insertion order.
directions = list(annotated_directions[model_name])
def inference(seed, direction):
    """Apply one pre-annotated PandA edit and return a side-by-side comparison.

    Looks up ``direction`` in ``annotated_directions[model_name]``, builds a
    ``Model`` at that direction's layer, loads the direction's two saved
    factor matrices (``Us`` and ``Uc``), and runs ``Model.edit_at_layer``.

    Args:
        seed: Value from the UI slider, forwarded to ``edit_at_layer`` as
            ``t``. It is cast to ``int`` because Gradio sliders may deliver
            floats. (Assumption: ``t`` acts as an integer sample seed —
            confirm against ``model.Model``.)
        direction: A key of ``annotated_directions[model_name]``.

    Returns:
        A PIL image with three panels concatenated along axis 1: the
        original image, the edited image, and a per-pixel mean-squared
        difference map (averaged over the last axis, saturated at 255,
        repeated to 3 channels).

    Raises:
        KeyError: If ``direction`` is unknown, or if its entry has no
            ``checkpoints_path``.
    """
    entry = annotated_directions[model_name][direction]
    M = Model(model_name, trunc_psi=1.0, device=device, layer=entry['layer'])
    M.ranks = entry['ranks']

    # Load the checkpoint. Keep the try narrow so that only a missing
    # 'checkpoints_path' entry is reported with this message.
    try:
        checkpoint_paths = entry['checkpoints_path']
    except KeyError as err:
        raise KeyError('ERROR: No directions specified in ./annotated_directions.py for this model') from err
    M.Us = torch.Tensor(np.load(checkpoint_paths[0])).to(device)
    M.Uc = torch.Tensor(np.load(checkpoint_paths[1])).to(device)

    part, appearance, lam = entry['parameters']
    _, image, image2, _ = M.edit_at_layer(
        [[part]], [appearance], [lam], t=int(seed), Uc=M.Uc, Us=M.Us, noise=None
    )

    # Compute the difference in a signed float dtype. Image.fromarray on the
    # concatenated 3-channel panel below requires uint8 images, and in uint8
    # ``image - image2`` would wrap modulo 256 and the square would overflow again.
    diff = np.asarray(image, dtype=np.float64) - np.asarray(image2, dtype=np.float64)
    mse = np.mean(diff ** 2, axis=-1)
    # Clip before narrowing to uint8 so large errors saturate instead of wrapping.
    dif = np.tile(np.clip(mse, 0, 255)[:, :, None], [1, 1, 3]).astype(np.uint8)
    return Image.fromarray(np.concatenate([image, image2, dif], 1))
# Gradio UI: a seed slider and a direction dropdown feed ``inference``. The
# single output panel shows original | edited | difference side by side,
# pre-filled with ./default.png before the first run.
demo = gr.Interface(
    fn=inference,
    inputs=[
        # Explicit step=1 so every integer seed in [0, 1000] is selectable.
        # Without it, Gradio may derive a coarser step from the range.
        gr.Slider(0, 1000, value=64, step=1),
        gr.Dropdown(directions, value='no_eyebrows'),
    ],
    outputs=[gr.Image(type="pil", value='./default.png', label="original | edited | mean-squared difference")],
    title="PandA (ICLR'23) - FFHQ edit zoo",
    description="Provides a quick interface to manipulate pre-annotated directions with pre-trained global parts and appearances factors. Note that we use the free CPU tier, so synthesis takes about 10 seconds.",
    article="Check out the full demo and paper at: <a href='https://github.com/james-oldfield/PandA'>https://github.com/james-oldfield/PandA</a>",
)

# Launch only when run as a script (as Spaces does with `python app.py`), so
# importing this module to reach `demo` does not start a server.
if __name__ == "__main__":
    demo.launch()