Spaces:
Build error
james-oldfield committed · ceb80dd
Parent(s): 2a76164

Upload 18 files

Files changed:
- annotated_directions.py +171 -0
- app.py +44 -0
- checkpoints/Uc-name_stylegan2_afhqdog512-layer_8-rank_512.npy +3 -0
- checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy +3 -0
- checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy +3 -0
- checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy +3 -0
- checkpoints/Us-name_stylegan2_afhqdog512-layer_8-rank_8.npy +3 -0
- checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy +3 -0
- checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy +3 -0
- checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy +3 -0
- demo.ipynb +0 -0
- environment.yml +25 -0
- ffhq-edit-zoo.ipynb +0 -0
- localize-concepts.ipynb +0 -0
- model.py +510 -0
- readme.md +82 -0
- requirements.txt +14 -0
- utils.py +110 -0
annotated_directions.py
ADDED
@@ -0,0 +1,171 @@
annotated_directions = {
    'stylegan2_ffhq1024': {
        # Directions used in paper with a single decomposition:
        'big_eyes': {
            'parameters': [7, 6, 30],  # used in main paper
            'layer': 5,
            'ranks': [512, 8],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy',
            ],
        },
        'long_nose': {
            'parameters': [5, 82, 30],  # used in main paper
            'layer': 5,
            'ranks': [512, 8],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy',
            ],
        },
        'smile': {
            'parameters': [4, 46, -30],  # used in sup. material
            'layer': 5,
            'ranks': [512, 8],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy',
            ],
        },
        'open_mouth': {
            'parameters': [4, 39, 30],  # used in sup. material
            'layer': 5,
            'ranks': [512, 8],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy',
            ],
        },

        # Additional directions
        'big_eyeballs': {
            'parameters': [8, 27, 100],
            'layer': 6,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
            ],
        },
        'wide_nose': {
            'parameters': [15, 13, 100],
            'layer': 6,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
            ],
        },
        'glance_left': {
            'parameters': [8, 281, 50],
            'layer': 6,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
            ],
        },
        'glance_right': {
            'parameters': [8, 281, -70],
            'layer': 6,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
            ],
        },
        'bald_forehead': {
            'parameters': [3, 25, 100],
            'layer': 6,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
            ],
        },
        'light_eyebrows': {
            'parameters': [8, 4, 30],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },
        'dark_eyebrows': {
            'parameters': [8, 9, 30],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },
        'no_eyebrows': {
            'parameters': [8, 4, 50],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },
        'dark_eyes': {
            'parameters': [11, 176, 50],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },
        'red_eyes': {
            'parameters': [11, 109, 60],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },
        'eyes_short': {
            'parameters': [11, 262, 70],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },
        'eyes_open': {
            'parameters': [11, 28, 50],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },
        'eyes_close': {
            'parameters': [11, 398, 80],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },
        'no_eyes': {
            'parameters': [11, 0, -200],
            'layer': 7,
            'ranks': [512, 16],
            'checkpoints_path': [
                './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
                './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
            ],
        },

    },

}
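
Each entry's 'parameters' triple unpacks as [part index, appearance index, edit magnitude]; this is how app.py below consumes it. A minimal lookup sketch for the 'big_eyes' entry (assuming the bundled checkpoint files above are present on disk):

    import numpy as np
    import torch
    from annotated_directions import annotated_directions

    entry = annotated_directions['stylegan2_ffhq1024']['big_eyes']
    part, appearance, lam = entry['parameters']               # part column, appearance column, edit strength
    Us = torch.Tensor(np.load(entry['checkpoints_path'][0]))  # spatial (parts) factors
    Uc = torch.Tensor(np.load(entry['checkpoints_path'][1]))  # channel (appearance) factors
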
app.py
ADDED
@@ -0,0 +1,44 @@
import numpy as np
import gradio as gr
import torch
from PIL import Image
from model import Model as Model
from annotated_directions import annotated_directions
device = torch.device('cpu')

torch.set_grad_enabled(False)
model_name = "stylegan2_ffhq1024"

directions = list(annotated_directions[model_name].keys())


def inference(seed, direction):
    layer = annotated_directions[model_name][direction]['layer']
    M = Model(model_name, trunc_psi=1.0, device=device, layer=layer)
    M.ranks = annotated_directions[model_name][direction]['ranks']

    # load the checkpoint
    try:
        M.Us = torch.Tensor(np.load(annotated_directions[model_name][direction]['checkpoints_path'][0])).to(device)
        M.Uc = torch.Tensor(np.load(annotated_directions[model_name][direction]['checkpoints_path'][1])).to(device)
    except KeyError:
        raise KeyError('ERROR: No directions specified in ./annotated_directions.py for this model')

    part, appearance, lam = annotated_directions[model_name][direction]['parameters']

    Z, image, image2, part_img = M.edit_at_layer([[part]], [appearance], [lam], t=seed, Uc=M.Uc, Us=M.Us, noise=None)

    dif = np.tile(((np.mean((image - image2)**2, -1)))[:, :, None], [1, 1, 3]).astype(np.uint8)

    return Image.fromarray(np.concatenate([image, image2, dif], 1))


demo = gr.Interface(
    fn=inference,
    inputs=[gr.Slider(0, 1000, value=64), gr.Dropdown(directions, value='no_eyebrows')],
    outputs=[gr.Image(type="pil", label="original | edited | mean-squared difference")],
    title="PandA (ICLR'23) - FFHQ edit zoo",
    description="Provides a quick interface to manipulate pre-annotated directions with pre-trained global parts and appearances factors. Note that we use the free CPU tier, so synthesis takes about 10 seconds.",
    article="Check out the full demo and paper at: <a href='https://github.com/james-oldfield/PandA'>https://github.com/james-oldfield/PandA</a>"
)
demo.launch()
checkpoints/Uc-name_stylegan2_afhqdog512-layer_8-rank_512.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3e40871d2c686e6552b44dd6e652a60c1596beb9f7a9d5ffb9d9ffdd5e97e53
size 1048704

checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e11bfd6bd179944dd89b6ff6e63829e595bdb7f926a4ff4cbc23e7adc91abb4
size 1048704

checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:55ed230d4f0dbaf1ed7849ada16cb95b29989efcd025c1dd96eda8c99245f287
size 1048704

checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7a2f1d9d3034bcec65aa5463986a1ddab35ba73d9e3c639499e02f33e58ece3
size 1048704

checkpoints/Us-name_stylegan2_afhqdog512-layer_8-rank_8.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4808e4fc3ab102d53d362ab5a783ac09a994cf826a6394e569b08148ffc985ca
size 131200

checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a154209cd3757fe5912e0530a1431e3aac67b2f34e73c6b73326af4a9f6ce0f4
size 8320

checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e6b7ef7e612b8b1840cf83b9e6f88438cd6858665e95cd4758ae0e5e8f06e374
size 65664

checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fdb99dcf6d5ccbad8e9da5709a30837b16c01686f5d68fd1922a8e8fb97c9e23
size 65664
demo.ipynb
ADDED
The diff for this file is too large to render. See the raw diff.
environment.yml
ADDED
@@ -0,0 +1,25 @@
name: PandA
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.5
  - pip=20.3
  - cudatoolkit=11.3
  - pytorch=1.11.0
  - torchvision=0.12.0
  - numpy=1.19.2
  - pip:
    - tensorly==0.7.0
    - opencv-python==4.1.2.30
    - boto3==1.21.8
    - botocore==1.24.8
    - imageio==2.14.1
    - matplotlib==3.5.1
    - nltk==3.7
    - numpy
    - Pillow==9.0.1
    - requests==2.27.0
    - bs4
    - ipykernel
    - jupyterlab
ffhq-edit-zoo.ipynb
ADDED
The diff for this file is too large to render. See the raw diff.

localize-concepts.ipynb
ADDED
The diff for this file is too large to render. See the raw diff.
model.py
ADDED
@@ -0,0 +1,510 @@
from networks.load_generator import load_generator
from networks.genforce.utils.visualizer import postprocess_image as postprocess
from networks.biggan import one_hot_from_names, truncated_noise_sample
from networks.stylegan3.load_stylegan3 import make_transform

from matplotlib import pyplot as plt

from utils import plot_masks, plot_colours, mapRange

import torch
import numpy as np
from PIL import Image
import tensorly as tl
tl.set_backend('pytorch')


class Model():
    def __init__(self, model_name, t=0, layer=5, trunc_psi=1.0, trunc_layers=18, device='cuda', biggan_classes=['fox']):
        """
        Instantiate the model for decomposition and/or local image editing.

        Parameters
        ----------
        model_name : string
            Name of architecture and dataset--one of the items in ./networks/genforce/models/model_zoo.py.
        t : int
            Random seed for the generator (to generate a sample image).
        layer : int
            Intermediate layer at which to perform the decomposition.
        trunc_psi : float
            Truncation value in [0, 1].
        trunc_layers : int
            Number of layers at which to apply truncation.
        device : string
            Device to store the tensors on.
        biggan_classes : list
            List of strings specifying imagenet classes of interest (e.g. ['alp', 'breakwater']).
        """
        self.gan_type = model_name.split('_')[0]
        self.model_name = model_name
        self.randomize_noise = False
        self.device = device
        self.biggan_classes = biggan_classes
        self.layer = layer  # layer to decompose
        self.start = 0 if 'stylegan2' in self.gan_type else 2
        self.trunc_psi = trunc_psi
        self.trunc_layers = trunc_layers

        self.generator = load_generator(model_name, device)
        noise = torch.Tensor(np.random.randn(1, self.generator.z_space_dim)).to(self.device)
        z, image = self.sample(noise, layer=layer, trunc_psi=trunc_psi, trunc_layers=trunc_layers, verbose=True)

        self.c = z.shape[1]
        self.s = z.shape[2]
        self.image = image

    def HOSVD(self, batch_size=10, n_iters=100):
        """
        Initialises the appearance basis A. In particular, computes the left-singular vectors of the channel mode's scatter matrix.

        Note: the total number of samples used is batch_size * n_iters.

        Parameters
        ----------
        batch_size : int
            Number of activations to sample in a single go.
        n_iters : int
            Number of times to sample `batch_size`-many activations.
        """
        np.random.seed(0)
        torch.manual_seed(0)

        with torch.no_grad():
            Z = torch.zeros((batch_size * n_iters, self.c, self.s, self.s), device=self.device)

            # note: perform in loops to have a larger effective batch size
            print('Starting loops...')
            for i in range(n_iters):
                np.random.seed(i)
                torch.manual_seed(i)
                noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
                z, _ = self.sample(noise, layer=self.layer, partial=True)

                Z[(batch_size * i):(batch_size * (i + 1))] = z

            Z = Z.view([-1, self.c, self.s**2])
            print(f'Generated {batch_size * n_iters} gan samples...')

            scat = 0
            for _, x in enumerate(Z):
                # mode-3 unfolding in the paper, but in PyTorch the channel mode is first.
                m_unfold = tl.unfold(x, 0)
                scat += m_unfold @ m_unfold.T

            self.Uc_init, _, _ = np.linalg.svd((scat / len(Z)).cpu().numpy())
            self.Uc_init = torch.Tensor(self.Uc_init).to(self.device)

            print('... HOSVD done')

    def decompose(self, ranks=[512, 8], lr=1e-8, batch_size=1, its=10000, log_modulo=1000, hosvd_init=True, stochastic=True, n_iters=1, verbose=True):
        """
        Performs the decomposition in the paper (Algorithm 1), either with a freshly sampled batch each iteration
        (stochastic=True), or by descending the full gradients on a fixed batch.

        Parameters
        ----------
        ranks : list
            List of integers specifying R_C and R_S, the ranks--i.e. the number of appearances and parts respectively.
        lr : float
            Learning rate for the projected gradient descent.
        batch_size : int
            Number of samples in each batch.
        its : int
            Total number of iterations.
        log_modulo : int
            Parameter used to control how often "training" information is displayed.
        hosvd_init : bool
            Initialise appearance factors from HOSVD? (else from random normal).
        stochastic : bool
            Sample the batch again each iteration? Else descend the full gradients.
        n_iters : int
            Number of `batch_size`-many samples to take (for the full gradient).
            The total activations are sampled in batches in a loop to enable it to fit in memory.
        verbose : bool
            Prints extra information.
        """
        self.ranks = ranks
        np.random.seed(0)
        torch.manual_seed(0)

        #######################
        # init from HOSVD, else random normal
        Uc = self.Uc_init[:, :ranks[0]].detach().clone().to(self.device) if hosvd_init else torch.randn(self.Uc_init.shape[0], ranks[0]).detach().clone().to(self.device) * 0.01
        Us = torch.Tensor(np.random.uniform(0, 0.01, size=[self.s**2, ranks[1]])).to(self.device)
        #######################

        print(f'Uc shape: {Uc.shape}, Us shape: {Us.shape}')

        with torch.no_grad():
            zeros = torch.zeros_like(Us, device=self.device)
            Us = torch.maximum(Us, zeros)

            # use a fixed batch (i.e. descend the full gradient)
            if not stochastic:
                Z = torch.zeros((batch_size * n_iters, self.c, self.s, self.s), device=self.device)

                # note: perform in loops to have a larger effective batch size
                print(f'Starting loops, total Z shape: {Z.shape}...')
                for i in range(n_iters):
                    np.random.seed(i)
                    torch.manual_seed(i)
                    noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
                    z, _ = self.sample(noise, layer=self.layer, partial=True)

                    Z[(batch_size * i):(batch_size * (i + 1))] = z

            for t in range(its):
                np.random.seed(t)
                torch.manual_seed(t)

                # resample the batch, if stochastic
                if stochastic:
                    noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
                    Z, _ = self.sample(noise, layer=self.layer, partial=True)

                if verbose:
                    # reconstruct (for visualisation)
                    coords = tl.tenalg.multi_mode_dot(Z.view(-1, self.c, self.s**2).float(), [Uc.T, Us.T], transpose=False, modes=[1, 2])
                    Z_rec = tl.tenalg.multi_mode_dot(coords, [Uc, Us], transpose=False, modes=[1, 2])

                    self.rec_loss = torch.mean(torch.norm(Z.view(-1, self.c, self.s**2).float() - Z_rec, p='fro', dim=[1, 2]) ** 2)

                # Update S
                z = Z.view(-1, self.c, self.s**2).float()
                Us_g = -4 * (torch.transpose(z, 1, 2)@Uc@Uc.T@z@Us) + \
                    2 * (Us@Us.T@torch.transpose(z, 1, 2)@Uc@Uc.T@Uc@Uc.T@z@Us + torch.transpose(z, 1, 2)@Uc@Uc.T@Uc@Uc.T@z@Us@Us.T@Us)
                Us_g = torch.sum(Us_g, 0)

                Us = Us - lr * Us_g
                # --- projection step ---
                Us = torch.maximum(Us, zeros)

                # Update C
                Uc_g = -4 * (z@Us@Us.T@torch.transpose(z, 1, 2)@Uc) + \
                    2 * (Uc@Uc.T@z@Us@Us.T@Us@Us.T@torch.transpose(z, 1, 2)@Uc + z@Us@Us.T@Us@Us.T@torch.transpose(z, 1, 2)@Uc@Uc.T@Uc)
                Uc_g = torch.sum(Uc_g, 0)
                Uc = Uc - lr * Uc_g

                self.Us = Us
                self.Uc = Uc

                if t % log_modulo == 0 and verbose:
                    print(f'ITERATION: {t}')
                    z, x = self.sample(noise, layer=self.layer, partial=False)

                    # here we display the learnt parts factors and also overlay them over the images to visualise.
                    plot_masks(Us.T, r=min(ranks[-1], 32), s=self.s)
                    plt.show()
                    plot_colours(x, Us.T, r=ranks[-1], s=self.s, seed=-1)
                    plt.show()

    def decompose_autograd(self, ranks=[512, 8], lr=1e-8, batch_size=1, its=10000, log_modulo=1000, verbose=True, hosvd_init=True):
        """
        Performs the same decomposition as in the paper, only using autograd with the Adam optimizer (and projected gradient descent).

        Parameters
        ----------
        ranks : list
            List of integers specifying R_C and R_S, the ranks--i.e. the number of appearances and parts respectively.
        lr : float
            Learning rate for the projected gradient descent.
        batch_size : int
            Number of samples in each batch.
        its : int
            Total number of iterations.
        log_modulo : int
            Parameter used to control how often "training" information is displayed.
        hosvd_init : bool
            Initialise appearance factors from HOSVD? (else from random normal).
        verbose : bool
            Prints extra information.
        """
        self.ranks = ranks
        np.random.seed(0)
        torch.manual_seed(0)

        #######################
        # init from HOSVD, else random normal
        Uc = torch.nn.Parameter(self.Uc_init[:, :ranks[0]].detach().clone().to(self.device), requires_grad=True) \
            if hosvd_init else torch.nn.Parameter(torch.randn(self.Uc_init.shape[0], ranks[0]).detach().clone().to(self.device) * 0.01)
        Us = torch.nn.Parameter(torch.Tensor(np.random.uniform(0, 0.01, size=[self.s**2, ranks[1]])).to(self.device), requires_grad=True)
        #######################
        optimizerS = torch.optim.Adam([Us], lr=lr)
        optimizerC = torch.optim.Adam([Uc], lr=lr)

        print(f'Uc shape: {Uc.shape}, Us shape: {Us.shape}')

        zeros = torch.zeros_like(Us, device=self.device)
        for t in range(its):
            np.random.seed(t)
            torch.manual_seed(t)

            noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
            Z, _ = self.sample(noise, layer=self.layer, partial=True)

            # Update S
            # reconstruct
            coords = tl.tenalg.multi_mode_dot(Z.view(-1, self.c, self.s**2).float(), [Uc.T, Us.T], transpose=False, modes=[1, 2])
            Z_rec = tl.tenalg.multi_mode_dot(coords, [Uc, Us], transpose=False, modes=[1, 2])

            rec_loss = torch.mean(torch.norm(Z.view(-1, self.c, self.s**2).float() - Z_rec, p='fro', dim=[1, 2]) ** 2)
            rec_loss.backward(retain_graph=True)

            optimizerS.step()
            # --- projection step ---
            Us.data = torch.maximum(Us.data, zeros)
            optimizerS.zero_grad()
            optimizerC.zero_grad()

            # Update C
            # reconstruct with updated Us
            coords = tl.tenalg.multi_mode_dot(Z.view(-1, self.c, self.s**2).float(), [Uc.T, Us.T], transpose=False, modes=[1, 2])
            Z_rec = tl.tenalg.multi_mode_dot(coords, [Uc, Us], transpose=False, modes=[1, 2])

            rec_loss = torch.mean(torch.norm(Z.view(-1, self.c, self.s**2).float() - Z_rec, p='fro', dim=[1, 2]) ** 2)
            rec_loss.backward()
            optimizerC.step()
            optimizerS.zero_grad()
            optimizerC.zero_grad()

            self.Us = Us
            self.Uc = Uc

            with torch.no_grad():
                if t % log_modulo == 0 and verbose:
                    print(f'Iteration {t} -- rec {rec_loss}')

                    noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
                    Z, x = self.sample(noise, layer=self.layer, partial=False)

                    plot_masks(Us.T, r=min(ranks[-1], 32), s=self.s)
                    plt.show()
                    plot_colours(x, Us.T, r=ranks[-1], s=self.s, seed=-1)
                    plt.show()

    def refine(self, Z, image, lr=1e-8, its=1000, log_modulo=250, verbose=True):
        """
        Performs the "refinement" step described in the paper, for a given sample Z.

        Parameters
        ----------
        Z : torch.Tensor
            Intermediate activations for target refinement.
        image : np.array
            Corresponding image for Z (purely for visualisation purposes).
        lr : float
            Learning rate for the projected gradient descent.
        its : int
            Total number of iterations.
        log_modulo : int
            Parameter used to control how often "training" information is displayed.
        verbose : bool
            Prints extra information.

        Returns
        -------
        UsR : torch.Tensor
            The refined factors \tilde{P}_i.
        """
        np.random.seed(0)
        torch.manual_seed(0)

        #######################
        # init from global spatial factors
        UsR = self.Us.clone()
        Uc = self.Uc
        #######################

        zeros = torch.zeros_like(self.Us, device=self.device)
        for t in range(its):
            with torch.no_grad():
                z = Z.view(-1, self.c, self.s**2).float()
                # descend the refinement term's gradient
                UsR_g = -4 * (torch.transpose(z, 1, 2)@Uc@Uc.T@z@UsR) + \
                    2 * (UsR@UsR.T@torch.transpose(z, 1, 2)@Uc@Uc.T@Uc@Uc.T@z@UsR + torch.transpose(z, 1, 2)@Uc@Uc.T@Uc@Uc.T@z@UsR@UsR.T@UsR)
                UsR_g = torch.sum(UsR_g, 0)

                # Update S
                UsR = UsR - lr * UsR_g
                # PGD step
                UsR = torch.maximum(UsR, zeros)

                if ((t + 1) % log_modulo == 0 and verbose):
                    print(f'iteration {t}')

                    plot_masks(UsR.T, s=self.s, r=min(self.ranks[-1], 16))
                    plt.show()
                    plot_colours(image, UsR.T, s=self.s, r=self.ranks[-1], seed=-1, alpha=0.9)
                    plt.show()

        return UsR

    def edit_at_layer(self, part, appearance, lam, t, Uc, Us, noise=None, b_idx=0):
        """
        Performs a local edit at layer `self.layer`, applying the chosen appearance(s) at the chosen part(s).

        Parameters
        ----------
        part : list
            List of ints containing the part(s) (column of Us) at which to edit.
        appearance : list
            List of ints containing the appearance (column of Uc) to apply at the corresponding part(s).
        lam : list
            List of ints containing the magnitude for each edit.
        t : int
            Random seed to edit.
        Uc : np.array
            Learnt appearance factors.
        Us : np.array
            Learnt parts factors.
        noise : np.array
            If specified, the target latent code itself to edit (i.e. instead of providing a random seed number).
        b_idx : int
            Index of the biggan categories to use.

        Returns
        -------
        Z : torch.Tensor
            The intermediate activation at layer `self.layer`.
        image : np.array
            The original image for sample `t` or from latent code `noise`.
        image2 : np.array
            The edited image.
        part : np.array
            The part used to edit.
        """
        with torch.no_grad():
            if noise is None:
                np.random.seed(t)
                torch.manual_seed(t)
                noise = torch.Tensor(np.random.randn(1, self.generator.z_space_dim)).to(self.device)
            else:
                np.random.seed(0)
                torch.manual_seed(0)

            direc = 0
            for i in range(len(appearance)):
                a = Uc[:, appearance[i]]
                p = torch.sum(Us[:, part[i]], dim=-1).reshape([self.s, self.s])
                p = mapRange(p, torch.min(p), torch.max(p), 0.0, 1.0)

                # here, we basically form a rank-1 "tensor", to add to the target sample's activations.
                # intuitively, the non-zero spatial positions of the part are filled with the appearance vector.
                direc += lam[i] * tl.tenalg.outer([a, p])

            if self.gan_type in ['stylegan', 'stylegan2']:
                noise = self.generator.mapping(noise)['w']
                noise_trunc = self.generator.truncation(noise, trunc_psi=self.trunc_psi, trunc_layers=self.trunc_layers)

                Z = self.generator.synthesis(noise_trunc, start=self.start, stop=self.layer)['x']

                x = self.generator.synthesis(noise_trunc, x=Z, start=self.layer)['image']
                x_prime = self.generator.synthesis(noise_trunc, x=Z + direc, start=self.layer)['image']
            elif 'pggan' in self.gan_type:
                Z = self.generator(noise, start=self.start, stop=self.layer)['x']

                x = self.generator(Z, start=self.layer)['image']
                x_prime = self.generator(Z + direc, start=self.layer)['image']
            elif 'biggan' in self.gan_type:
                print(f'Choosing a {self.biggan_classes[b_idx]}')
                class_vector = torch.tensor(one_hot_from_names([self.biggan_classes[b_idx]]), device=self.device)
                noise_vector = torch.tensor(truncated_noise_sample(truncation=self.trunc_psi, batch_size=1, seed=t), device=self.device)

                result = self.generator(noise_vector, class_vector, self.trunc_psi, stop=self.layer)
                Z, cond_vector = result['z'], result['cond_vector']
                x = self.generator(Z, class_vector, self.trunc_psi, cond_vector=cond_vector, start=self.layer)['z']
                x_prime = self.generator(Z + direc, class_vector, self.trunc_psi, cond_vector=cond_vector, start=self.layer)['z']
            elif 'stylegan3' in self.gan_type:
                label = torch.zeros([1, 0], device=self.device)
                Z = self.generator(noise, label, stop=self.layer, truncation_psi=self.trunc_psi, noise_mode='const')

                x = self.generator(noise, label, x=Z, start=self.layer, stop=None, truncation_psi=self.trunc_psi, noise_mode='const')
                x_prime = self.generator(noise, label, x=Z + direc, start=self.layer, stop=None, truncation_psi=self.trunc_psi, noise_mode='const')

            image = np.array(Image.fromarray(postprocess(x.cpu().numpy())[0]).resize((256, 256)))
            image2 = np.array(Image.fromarray(postprocess(x_prime.cpu().numpy())[0]).resize((256, 256)))

            part = np.array(Image.fromarray(p.detach().cpu().numpy() * 255).convert('RGB').resize((256, 256), Image.NEAREST))
            return Z, image, image2, part

    def sample(self, noise, layer=5, partial=False, trunc_psi=1.0, trunc_layers=18, verbose=False):
        """
        Samples intermediate feature maps, and the resulting image, from the desired generator.

        Parameters
        ----------
        noise : np.array
            (batch_size, z_dim)-dim random standard gaussian noise.
        layer : int
            Intermediate layer at which to return intermediate features.
        partial : bool
            Return only the intermediate activations at layer number `layer`? Else perform the full forward pass and return the image too.
        trunc_psi : float
            Truncation value in [0, 1].
        trunc_layers : int
            Number of layers at which to apply truncation.
        verbose : bool
            Print out additional information?

        Returns
        -------
        Z : torch.Tensor
            The intermediate activations of shape [C, H, W].
        image : np.array
            Output RGB image.
        """
        with torch.no_grad():
            if self.gan_type in ['stylegan', 'stylegan2']:
                noise = self.generator.mapping(noise)['w']
                noise_trunc = self.generator.truncation(noise, trunc_psi=trunc_psi, trunc_layers=trunc_layers)
                Z = self.generator.synthesis(noise_trunc, start=self.start, stop=layer)['x']
                if not partial:
                    x = self.generator.synthesis(noise_trunc, x=Z, start=layer)['image']
            elif 'pggan' in self.gan_type:
                Z = self.generator(noise, start=self.start, stop=layer)['x']
                if not partial:
                    x = self.generator(Z, start=layer)['image']
            elif 'biggan' in self.gan_type:
                if verbose:
                    print(f'Using BigGAN class names: {", ".join(self.biggan_classes)}')

                class_vector = torch.tensor(one_hot_from_names(list(np.random.choice(self.biggan_classes, noise.shape[0])), batch_size=noise.shape[0]), device=self.device)
                noise_vector = torch.tensor(truncated_noise_sample(truncation=self.trunc_psi, batch_size=noise.shape[0]), device=self.device)

                result = self.generator(noise_vector, class_vector, self.trunc_psi, stop=layer)
                Z = result['z']
                cond_vector = result['cond_vector']

                if not partial:
                    x = self.generator(Z, class_vector, self.trunc_psi, cond_vector=cond_vector, start=layer)['z']
            elif 'stylegan3' in self.gan_type:
                label = torch.zeros([noise.shape[0], 0], device=self.device)
                if hasattr(self.generator.synthesis, 'input'):
                    m = np.linalg.inv(make_transform((0, 0), 0))
                    self.generator.synthesis.input.transform.copy_(torch.from_numpy(m))

                Z = self.generator(noise, label, x=None, start=0, stop=layer, truncation_psi=trunc_psi, noise_mode='const')
                if not partial:
                    x = self.generator(noise, label, x=Z, start=layer, stop=None, truncation_psi=trunc_psi, noise_mode='const')

            if verbose:
                print(f'-- Partial Z shape at layer {layer}: {Z.shape}')

            if partial:
                return Z, None
            else:
                image = postprocess(x.detach().cpu().numpy())
                image = np.array(Image.fromarray(image[0]).resize((256, 256)))
                return Z, image

    def save(self):
        Uc_path = f'./checkpoints/Uc-name_{self.model_name}-layer_{self.layer}-rank_{self.ranks[0]}.npy'
        Us_path = f'./checkpoints/Us-name_{self.model_name}-layer_{self.layer}-rank_{self.ranks[1]}.npy'

        np.save(Us_path, self.Us.detach().cpu().numpy())
        np.save(Uc_path, self.Uc.detach().cpu().numpy())

        print(f'Saved factors to {Us_path}, {Uc_path}')
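
The methods above compose into the decomposition pipeline used in demo.ipynb (whose diff is not rendered here). A minimal sketch, assuming a CUDA device and the 'stylegan2_ffhq1024' generator resolved by load_generator; the hyperparameters and the chosen part/appearance indices are illustrative, not the paper's exact settings:

    import torch
    from model import Model

    torch.set_grad_enabled(False)

    M = Model('stylegan2_ffhq1024', layer=5, trunc_psi=1.0, device='cuda')
    M.HOSVD(batch_size=10, n_iters=100)               # initialise the appearance basis from the channel scatter matrix
    M.decompose(ranks=[512, 8], lr=1e-8, its=10000)   # Algorithm 1: learn Uc (appearances) and Us (parts)
    M.save()                                          # writes ./checkpoints/{Uc,Us}-name_..._rank_*.npy

    # a single local edit, matching the 'big_eyes' parameters [7, 6, 30]: part 7, appearance 6, magnitude 30
    Z, image, edited, part_mask = M.edit_at_layer([[7]], [6], [30], t=0, Uc=M.Uc, Us=M.Us)
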
readme.md
ADDED
@@ -0,0 +1,82 @@
# PandA: Unsupervised Learning of Parts and Appearances in the Feature Maps of GANs

## [ [paper](https://openreview.net/pdf?id=iUdSB2kK9GY) | [project page](http://eecs.qmul.ac.uk/~jo001/PandA/) | [video](https://www.youtube.com/watch?v=1KY055goKP0) | [edit zoo](https://colab.research.google.com/github/james-oldfield/PandA/blob/main/ffhq-edit-zoo.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/james-oldfield/PandA/blob/main/demo.ipynb) ]

![main.jpg](./images/main.jpg)

> **PandA: Unsupervised Learning of Parts and Appearances in the Feature Maps of GANs**<br>
> James Oldfield, Christos Tzelepis, Yannis Panagakis, Mihalis A. Nicolaou, and Ioannis Patras<br>
> *International Conference on Learning Representations (ICLR)*, 2023 <br>
> https://arxiv.org/abs/2206.00048 <br>
>
> **Abstract**: Recent advances in the understanding of Generative Adversarial Networks (GANs) have led to remarkable progress in visual editing and synthesis tasks, capitalizing on the rich semantics that are embedded in the latent spaces of pre-trained GANs. However, existing methods are often tailored to specific GAN architectures and are limited to either discovering global semantic directions that do not facilitate localized control, or require some form of supervision through manually provided regions or segmentation masks. In this light, we present an architecture-agnostic approach that jointly discovers factors representing spatial parts and their appearances in an entirely unsupervised fashion. These factors are obtained by applying a semi-nonnegative tensor factorization on the feature maps, which in turn enables context-aware local image editing with pixel-level control. In addition, we show that the discovered appearance factors correspond to saliency maps that localize concepts of interest, without using any labels. Experiments on a wide range of GAN architectures and datasets show that, in comparison to the state of the art, our method is far more efficient in terms of training time and, most importantly, provides much more accurate localized control.

![cat-gif](./images/cat-eye-control.gif)
> An example of using our learnt appearances and semantic parts for local image editing.

## Experiments

We provide a number of notebooks to reproduce the experiments in the paper and to explore the model. Please see the following notebooks:

# [`./demo.ipynb`](./demo.ipynb)

This notebook contains the code to learn the parts and appearance factors at a target layer in a target GAN. It also contains code for local image editing using the learnt parts, and for refining the parts factors (a minimal refinement sketch follows the table below).

| Local image editing (at the learnt semantic parts) | |
| :-- | :-- |
| ![image](./images/l8-t645-Rs16-Rc512-rTrue-lam[100]-p[6]-start-end.gif) | ![image](./images/l8-t16-Rs8-Rc512-rTrue-lam-150-p2-start-end.gif) |

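A minimal sketch of the refinement step, assuming `M` is a `Model` whose factors `M.Us`/`M.Uc` and `M.ranks` have already been set (e.g. via `decompose()`, or by loading a checkpoint as in `app.py`); the part/appearance indices and hyperparameters are illustrative:

```python
import numpy as np
import torch

# sample a latent and grab its activations at the decomposed layer
noise = torch.Tensor(np.random.randn(1, M.generator.z_space_dim)).to(M.device)
Z, image = M.sample(noise, layer=M.layer)

# per-sample refinement of the parts factors
Us_refined = M.refine(Z, image, lr=1e-8, its=1000)

# edit using the refined parts instead of the global ones
_, original, edited, part_mask = M.edit_at_layer([[0]], [0], [50], t=0, Uc=M.Uc, Us=Us_refined, noise=noise)
```
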
# [`./localize-concepts.ipynb`](./localize-concepts.ipynb)

Provides code to localize/visualize concepts of interest for a given model/dataset (set up for the "background" concept in `stylegan2_afhqdog512` as an example).

| Localizing the learnt "background" concept vector |
| :-- |
| ![image](./images/mask-bg-42.gif) ![image](./images/mask-bg-83.gif) ![image](./images/mask-bg-29.gif) |

# [`./ffhq-edit-zoo.ipynb`](./ffhq-edit-zoo.ipynb)

Quickly produce edits with annotated directions, using pre-trained factors on FFHQ StyleGAN2.

| Local image editing: "Big eyes" |
| :-- |
| ![image](./images/qualitative.png) |

## Setup

Should you wish to run the notebooks, please consult the instructions below:

### Install
First, please install the dependencies with `pip install -r requirements.txt`, or alternatively with conda using `conda env create -f environment.yml`.

### Pre-trained models
Should you wish to download the pre-trained models to run the notebooks, please first download them with:

```bash
wget -r -np -nH --cut-dirs=2 -R *index* http://eecs.qmul.ac.uk/~jo001/PandA-pretrained-models/
```

# citation

If you find our work useful, please consider citing our paper:

```bibtex
@inproceedings{oldfield2023panda,
  title={PandA: Unsupervised Learning of Parts and Appearances in the Feature Maps of GANs},
  author={James Oldfield and Christos Tzelepis and Yannis Panagakis and Mihalis A. Nicolaou and Ioannis Patras},
  booktitle={Int. Conf. Learn. Represent.},
  year={2023}
}
```

# contact

**Please feel free to get in touch at**: `j.a.oldfield@qmul.ac.uk`

---

## credits

- `./networks/genforce/` contains mostly code directly from [https://github.com/genforce/genforce](https://github.com/genforce/genforce).
- `./networks/biggan/` contains mostly code directly from [https://github.com/huggingface/pytorch-pretrained-BigGAN](https://github.com/huggingface/pytorch-pretrained-BigGAN).
- `./networks/stylegan3/` contains mostly code directly from [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
requirements.txt
ADDED
@@ -0,0 +1,14 @@
boto3==1.21.8
botocore==1.24.8
bs4
imageio==2.14.1
jupyterlab
matplotlib==3.5.1
nltk==3.7
numpy
opencv-python==4.2.0.32
Pillow==9.0.1
requests==2.27.0
tensorly==0.7.0
torch==1.10.2
torchvision==0.11.3
utils.py
ADDED
@@ -0,0 +1,110 @@
from matplotlib import pyplot as plt
from matplotlib import gridspec
import matplotlib.patches as mpatches
import torch
import numpy as np
from PIL import Image


def get_cols():
    # list of perceptually distinct colours (for spatial factor plots)
    return np.array([[255, 0, 0], [255, 255, 0], [0, 234, 255], [170, 0, 255], [255, 127, 0], [191, 255, 0], [0, 149, 255], [255, 0, 170], [255, 212, 0], [106, 255, 0], [0, 64, 255], [237, 185, 185], [185, 215, 237], [231, 233, 185], [220, 185, 237], [185, 237, 224], [143, 35, 35], [35, 98, 143], [143, 106, 35], [107, 35, 143], [79, 143, 35], [0, 0, 0], [115, 115, 115], [204, 204, 204]])


def mapRange(value, inMin, inMax, outMin, outMax):
    return outMin + (((value - inMin) / (inMax - inMin)) * (outMax - outMin))


def plot_masks(Us, r, s, rs=256, save_path=None, title_factors=True):
    """
    Plots the parts factors with matplotlib for visualization.

    Parameters
    ----------
    Us : np.array
        Learnt parts factor matrix.
    r : int
        Number of factors to show.
    s : int
        Dimensions of each part (h*w).
    rs : int
        Target size to downsize images to.
    save_path : bool
        Save figure?
    title_factors : bool
        Print matplotlib title on each part?
    """

    fig = plt.figure(constrained_layout=True, figsize=(20, 3))
    spec = gridspec.GridSpec(ncols=r + 1, nrows=1, figure=fig)

    for i in range(0, r):
        fig.add_subplot(spec[i])

        if title_factors:
            plt.title(f'Part {i}')

        part = Us[i].reshape([s, s])
        part = mapRange(part, torch.min(part), torch.max(part), 0.0, 1.0) * 255
        part = part.detach().cpu().numpy()
        part = np.array(Image.fromarray(np.uint8(part)).convert('RGBA').resize((rs, rs), Image.NEAREST)) / 255

        plt.axis('off')
        plt.imshow(part, vmin=1, vmax=1, cmap='gray', alpha=1.00)

    if save_path is not None:
        plt.savefig(save_path)


def plot_colours(image, Us, r, s, rs=128, save_path=None, alpha=1.0, seed=-1, legend=True):
    """
    Plots the parts factors over an image with matplotlib for visualization.

    Parameters
    ----------
    image : np.array
        Image to visualize.
    Us : np.array
        Learnt parts factor matrix.
    r : int
        Number of factors to show.
    s : int
        Dimensions of each part (h*w).
    rs : int
        Target size to downsize images to.
    alpha : float
        Alpha value for the masks.
    seed : int
        Random seed when generating the colour palette (use -1 to use the provided "perceptually distinct" colour palette, but note this has a maximum of 30 colours or so).
    legend : bool
        Plot the legend, detailing the colour-coded parts key?
    """

    img = Image.fromarray(image).resize((rs, rs)).convert('RGBA')

    # Use perceptually distinct colour list, or a random seed (e.g. if you have too many factors)
    cols = get_cols()
    if seed >= 0:
        np.random.seed(seed)
        cols = np.random.randint(0, 255, [r, 3])

    plt.imshow(img, alpha=1.0)
    plt.axis('off')

    patches = []
    for i in range(0, r):
        mask = Us[i].detach().cpu().numpy().reshape([s, s])
        mask = mapRange(mask, np.min(mask), np.max(mask), 0, 255)
        mask = np.uint8(mask)
        mask = np.array(Image.fromarray(mask).convert('L').resize((rs, rs)))
        mask = (mask[:, :, None] / 255.) * np.array(np.concatenate([cols[i] / 255, [1]]))

        patches += [mpatches.Patch(color=cols[i] / 255, label=f'Part {i}')]

        plt.imshow(mask, vmin=0, vmax=1, alpha=alpha)

    if legend:
        plt.legend(title='Spatial factors', handles=patches, bbox_to_anchor=(1.01, 1.01), loc="upper left")

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)
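
A quick illustrative call with random data, purely to show the expected layout of the factor matrix passed to these helpers (one row per part, each of length s*s); the shapes here are hypothetical stand-ins, not learnt factors:

    import torch
    from matplotlib import pyplot as plt
    from utils import plot_masks

    s, r = 16, 8               # spatial resolution and number of parts (illustrative values)
    Us = torch.rand(r, s * s)  # stand-in for learnt parts factors; model.py passes Us.T in this layout
    plot_masks(Us, r=r, s=s)   # one grayscale panel per part
    plt.show()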