james-oldfield committed on
Commit
ceb80dd
1 Parent(s): 2a76164

Upload 18 files

annotated_directions.py ADDED
@@ -0,0 +1,171 @@
1
+ annotated_directions = {
2
+ 'stylegan2_ffhq1024': {
3
+ # Directions used in paper with a single decomposition:
4
+ 'big_eyes': {
5
+ 'parameters': [7, 6, 30], # used in main paper
6
+ 'layer': 5,
7
+ 'ranks': [512, 8],
8
+ 'checkpoints_path': [
9
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy',
10
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy',
11
+ ],
12
+ },
13
+ 'long_nose': {
14
+ 'parameters': [5, 82, 30], # used in main paper
15
+ 'layer': 5,
16
+ 'ranks': [512, 8],
17
+ 'checkpoints_path': [
18
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy',
19
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy',
20
+ ],
21
+ },
22
+ 'smile': {
23
+ 'parameters': [4, 46, -30], # used in sup. material
24
+ 'layer': 5,
25
+ 'ranks': [512, 8],
26
+ 'checkpoints_path': [
27
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy',
28
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy',
29
+ ],
30
+ },
31
+ 'open_mouth': {
32
+ 'parameters': [4, 39, 30], # used in sup. material
33
+ 'layer': 5,
34
+ 'ranks': [512, 8],
35
+ 'checkpoints_path': [
36
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy',
37
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy',
38
+ ],
39
+ },
40
+
41
+ # Additional directions
42
+ 'big_eyeballs': {
43
+ 'parameters': [8, 27, 100],
44
+ 'layer': 6,
45
+ 'ranks': [512, 16],
46
+ 'checkpoints_path': [
47
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
48
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
49
+ ],
50
+ },
51
+ 'wide_nose': {
52
+ 'parameters': [15, 13, 100],
53
+ 'layer': 6,
54
+ 'ranks': [512, 16],
55
+ 'checkpoints_path': [
56
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
57
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
58
+ ],
59
+ },
60
+ 'glance_left': {
61
+ 'parameters': [8, 281, 50],
62
+ 'layer': 6,
63
+ 'ranks': [512, 16],
64
+ 'checkpoints_path': [
65
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
66
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
67
+ ],
68
+ },
69
+ 'glance_right': {
70
+ 'parameters': [8, 281, -70],
71
+ 'layer': 6,
72
+ 'ranks': [512, 16],
73
+ 'checkpoints_path': [
74
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
75
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
76
+ ],
77
+ },
78
+ 'bald_forehead': {
79
+ 'parameters': [3, 25, 100],
80
+ 'layer': 6,
81
+ 'ranks': [512, 16],
82
+ 'checkpoints_path': [
83
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy',
84
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy',
85
+ ],
86
+ },
87
+ 'light_eyebrows': {
88
+ 'parameters': [8, 4, 30],
89
+ 'layer': 7,
90
+ 'ranks': [512, 16],
91
+ 'checkpoints_path': [
92
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
93
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
94
+ ],
95
+ },
96
+ 'dark_eyebrows': {
97
+ 'parameters': [8, 9, 30],
98
+ 'layer': 7,
99
+ 'ranks': [512, 16],
100
+ 'checkpoints_path': [
101
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
102
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
103
+ ],
104
+ },
105
+ 'no_eyebrows': {
106
+ 'parameters': [8, 4, 50],
107
+ 'layer': 7,
108
+ 'ranks': [512, 16],
109
+ 'checkpoints_path': [
110
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
111
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
112
+ ],
113
+ },
114
+ 'dark_eyes': {
115
+ 'parameters': [11, 176, 50],
116
+ 'layer': 7,
117
+ 'ranks': [512, 16],
118
+ 'checkpoints_path': [
119
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
120
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
121
+ ],
122
+ },
123
+ 'red_eyes': {
124
+ 'parameters': [11, 109, 60],
125
+ 'layer': 7,
126
+ 'ranks': [512, 16],
127
+ 'checkpoints_path': [
128
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
129
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
130
+ ],
131
+ },
132
+ 'eyes_short': {
133
+ 'parameters': [11, 262, 70],
134
+ 'layer': 7,
135
+ 'ranks': [512, 16],
136
+ 'checkpoints_path': [
137
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
138
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
139
+ ],
140
+ },
141
+ 'eyes_open': {
142
+ 'parameters': [11, 28, 50],
143
+ 'layer': 7,
144
+ 'ranks': [512, 16],
145
+ 'checkpoints_path': [
146
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
147
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
148
+ ],
149
+ },
150
+ 'eyes_close': {
151
+ 'parameters': [11, 398, 80],
152
+ 'layer': 7,
153
+ 'ranks': [512, 16],
154
+ 'checkpoints_path': [
155
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
156
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
157
+ ],
158
+ },
159
+ 'no_eyes': {
160
+ 'parameters': [11, 0, -200],
161
+ 'layer': 7,
162
+ 'ranks': [512, 16],
163
+ 'checkpoints_path': [
164
+ './checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy',
165
+ './checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy',
166
+ ],
167
+ },
168
+
169
+ },
170
+
171
+ }
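+
+ # Usage sketch (added for illustration; not part of the original upload):
+ # iterate over the annotated directions above and check that the referenced
+ # checkpoint files are present on disk.
+ if __name__ == '__main__':
+     import os
+     for model, directions in annotated_directions.items():
+         for name, spec in directions.items():
+             part, appearance, magnitude = spec['parameters']
+             missing = [p for p in spec['checkpoints_path'] if not os.path.exists(p)]
+             print(f"{model}/{name}: layer {spec['layer']}, ranks {spec['ranks']}, "
+                   f"edit (part {part}, appearance {appearance}, magnitude {magnitude}), "
+                   f"missing checkpoints: {missing or 'none'}")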
app.py ADDED
@@ -0,0 +1,44 @@
1
+ import numpy as np
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ from model import Model as Model
6
+ from annotated_directions import annotated_directions
7
+ device = torch.device('cpu')
8
+
9
+ torch.set_grad_enabled(False)
10
+ model_name = "stylegan2_ffhq1024"
11
+
12
+ directions = list(annotated_directions[model_name].keys())
13
+
14
+
15
+ def inference(seed, direction):
16
+ layer = annotated_directions[model_name][direction]['layer']
17
+ M = Model(model_name, trunc_psi=1.0, device=device, layer=layer)
18
+ M.ranks = annotated_directions[model_name][direction]['ranks']
19
+
20
+ # load the checkpoint
21
+ try:
22
+ M.Us = torch.Tensor(np.load(annotated_directions[model_name][direction]['checkpoints_path'][0])).to(device)
23
+ M.Uc = torch.Tensor(np.load(annotated_directions[model_name][direction]['checkpoints_path'][1])).to(device)
24
+ except KeyError:
25
+ raise KeyError('ERROR: No directions specified in ./annotated_directions.py for this model')
26
+
27
+ part, appearance, lam = annotated_directions[model_name][direction]['parameters']
28
+
29
+ Z, image, image2, part_img = M.edit_at_layer([[part]], [appearance], [lam], t=seed, Uc=M.Uc, Us=M.Us, noise=None)
30
+
31
+ dif = np.tile(((np.mean((image - image2)**2, -1)))[:,:,None], [1,1,3]).astype(np.uint8)
32
+
33
+ return Image.fromarray(np.concatenate([image, image2, dif], 1))
34
+
35
+
36
+ demo = gr.Interface(
37
+ fn=inference,
38
+ inputs=[gr.Slider(0, 1000, value=64), gr.Dropdown(directions, value='no_eyebrows')],
39
+ outputs=[gr.Image(type="pil", label="original | edited | mean-squared difference")],
40
+ title="PandA (ICLR'23) - FFHQ edit zoo",
41
+ description="Provides a quick interface to manipulate pre-annotated directions with pre-trained global parts and appearance factors. Note that we use the free CPU tier, so synthesis takes about 10 seconds.",
42
+ article="Check out the full demo and paper at: <a href='https://github.com/james-oldfield/PandA'>https://github.com/james-oldfield/PandA</a>"
43
+ )
44
+ demo.launch()
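+
+ # Usage note (added for illustration; not part of the original file): the Gradio
+ # UI above simply calls `inference(seed, direction)`, which returns a PIL image of
+ # [original | edited | mean-squared difference]. It can also be called directly, e.g.:
+ #     img = inference(64, 'big_eyes')
+ #     img.save('big_eyes_seed64.png')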
checkpoints/Uc-name_stylegan2_afhqdog512-layer_8-rank_512.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3e40871d2c686e6552b44dd6e652a60c1596beb9f7a9d5ffb9d9ffdd5e97e53
3
+ size 1048704
checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e11bfd6bd179944dd89b6ff6e63829e595bdb7f926a4ff4cbc23e7adc91abb4
3
+ size 1048704
checkpoints/Uc-name_stylegan2_ffhq1024-layer_6-rank_512.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55ed230d4f0dbaf1ed7849ada16cb95b29989efcd025c1dd96eda8c99245f287
3
+ size 1048704
checkpoints/Uc-name_stylegan2_ffhq1024-layer_7-rank_512.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7a2f1d9d3034bcec65aa5463986a1ddab35ba73d9e3c639499e02f33e58ece3
3
+ size 1048704
checkpoints/Us-name_stylegan2_afhqdog512-layer_8-rank_8.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4808e4fc3ab102d53d362ab5a783ac09a994cf826a6394e569b08148ffc985ca
3
+ size 131200
checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a154209cd3757fe5912e0530a1431e3aac67b2f34e73c6b73326af4a9f6ce0f4
3
+ size 8320
checkpoints/Us-name_stylegan2_ffhq1024-layer_6-rank_16.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6b7ef7e612b8b1840cf83b9e6f88438cd6858665e95cd4758ae0e5e8f06e374
3
+ size 65664
checkpoints/Us-name_stylegan2_ffhq1024-layer_7-rank_16.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fdb99dcf6d5ccbad8e9da5709a30837b16c01686f5d68fd1922a8e8fb97c9e23
3
+ size 65664
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
environment.yml ADDED
@@ -0,0 +1,25 @@
1
+ name: PandA
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.11.0
10
+ - torchvision=0.12.0
11
+ - numpy=1.19.2
12
+ - pip:
13
+ - tensorly==0.7.0
14
+ - opencv-python==4.1.2.30
15
+ - boto3==1.21.8
16
+ - botocore==1.24.8
17
+ - imageio==2.14.1
18
+ - matplotlib==3.5.1
19
+ - nltk==3.7
20
+ - numpy
21
+ - Pillow==9.0.1
22
+ - requests==2.27.0
23
+ - bs4
24
+ - ipykernel
25
+ - jupyterlab
ffhq-edit-zoo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
localize-concepts.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,510 @@
1
+ from networks.load_generator import load_generator
2
+ from networks.genforce.utils.visualizer import postprocess_image as postprocess
3
+ from networks.biggan import one_hot_from_names, truncated_noise_sample
4
+ from networks.stylegan3.load_stylegan3 import make_transform
5
+
6
+ from matplotlib import pyplot as plt
7
+
8
+ from utils import plot_masks, plot_colours, mapRange
9
+
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image
13
+ import tensorly as tl
14
+ tl.set_backend('pytorch')
15
+
16
+
17
+ class Model():
18
+ def __init__(self, model_name, t=0, layer=5, trunc_psi=1.0, trunc_layers=18, device='cuda', biggan_classes=['fox']):
19
+ """
20
+ Instantiate the model for decomposition and/or local image editing.
21
+
22
+ Parameters
23
+ ----------
24
+ model_name : string
25
+ Name of architecture and dataset--one of the items in ./networks/genforce/models/model_zoo.py.
26
+ t : int
27
+ Random seed for the generator (to generate a sample image).
28
+ layer : int
29
+ Intermediate layer at which to perform the decomposition.
30
+ trunc_psi : float
31
+ Truncation value in [0, 1].
32
+ trunc_layers : int
33
+ Number of layers at which to apply truncation.
34
+ device : string
35
+ Device to store the tensors on.
36
+ biggan_classes : list
37
+ List of strings specifying imagenet classes of interest (e.g. ['alp', 'breakwater']).
38
+ """
39
+ self.gan_type = model_name.split('_')[0]
40
+ self.model_name = model_name
41
+ self.randomize_noise = False
42
+ self.device = device
43
+ self.biggan_classes = biggan_classes
44
+ self.layer = layer # layer to decompose
45
+ self.start = 0 if 'stylegan2' in self.gan_type else 2
46
+ self.trunc_psi = trunc_psi
47
+ self.trunc_layers = trunc_layers
48
+
49
+ self.generator = load_generator(model_name, device)
50
+ noise = torch.Tensor(np.random.randn(1, self.generator.z_space_dim)).to(self.device)
51
+ z, image = self.sample(noise, layer=layer, trunc_psi=trunc_psi, trunc_layers=trunc_layers, verbose=True)
52
+
53
+ self.c = z.shape[1]
54
+ self.s = z.shape[2]
55
+ self.image = image
56
+
57
+ def HOSVD(self, batch_size=10, n_iters=100):
58
+ """
59
+ Initialises the appearance basis A. In particular, computes the left-singular vectors of the channel mode's scatter matrix.
60
+
61
+ Note: total samples used is batch_size * n_iters
62
+
63
+ Parameters
64
+ ----------
65
+ batch_size : int
66
+ Number of activations to sample in a single go.
67
+ n_iters : int
68
+ Number of times to sample `batch_size`-many activations.
69
+ """
70
+ np.random.seed(0)
71
+ torch.manual_seed(0)
72
+
73
+ with torch.no_grad():
74
+ Z = torch.zeros((batch_size * n_iters, self.c, self.s, self.s), device=self.device)
75
+
76
+ # note: perform in loops to have a larger effective batch size
77
+ print('Starting loops...')
78
+ for i in range(n_iters):
79
+ np.random.seed(i)
80
+ torch.manual_seed(i)
81
+ noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
82
+ z, _ = self.sample(noise, layer=self.layer, partial=True)
83
+
84
+ Z[(batch_size * i):(batch_size * (i + 1))] = z
85
+
86
+ Z = Z.view([-1, self.c, self.s**2])
87
+ print(f'Generated {batch_size * n_iters} gan samples...')
88
+
89
+ scat = 0
90
+ for _, x in enumerate(Z):
91
+ # mode-3 unfolding in the paper, but in PyTorch channel mode is first.
92
+ m_unfold = tl.unfold(x, 0)
93
+ scat += m_unfold @ m_unfold.T
94
+
95
+ self.Uc_init, _, _ = np.linalg.svd((scat / len(Z)).cpu().numpy())
96
+ self.Uc_init = torch.Tensor(self.Uc_init).to(self.device)
97
+
98
+ print('... HOSVD done')
99
+
100
+ def decompose(self, ranks=[512, 8], lr=1e-8, batch_size=1, its=10000, log_modulo=1000, hosvd_init=True, stochastic=True, n_iters=1, verbose=True):
101
+ """
102
+ Performs the decomposition in the paper. In particular, Algorithm 1.,
103
+ either with a non-fixed batch of samples (stochastic=True), or descends the full gradients.
104
+
105
+ Parameters
106
+ ----------
107
+ ranks : list
108
+ List of two integers [R_C, R_S] specifying the ranks, i.e. the number of appearance and parts factors respectively.
109
+ lr : float
110
+ Learning rate for the projected gradient descent.
111
+ batch_size : int
112
+ Number of samples in each batch.
113
+ its : int
114
+ Total number of iterations.
115
+ log_modulo : int
116
+ Parameter used to control how often "training" information is displayed.
117
+ hosvd_init : bool
118
+ Initialise appearance factors from HOSVD? (else from random normal).
119
+ stochastic : bool
120
+ Resample the batch at each iteration? Otherwise, descend the full gradient.
121
+ n_iters : int
122
+ Number of `batch_size`-many samples to take (for full gradient).
123
+ The total set of activations is sampled in batches in a loop so that it fits in memory.
124
+ verbose : bool
125
+ Prints extra information.
126
+ """
127
+ self.ranks = ranks
128
+ np.random.seed(0)
129
+ torch.manual_seed(0)
130
+
131
+ #######################
132
+ # init from HOSVD, else random normal
133
+ Uc = self.Uc_init[:, :ranks[0]].detach().clone().to(self.device) if hosvd_init else torch.randn(self.Uc_init.shape[0], ranks[0]).detach().clone().to(self.device) * 0.01
134
+ Us = torch.Tensor(np.random.uniform(0, 0.01, size=[self.s**2, ranks[1]])).to(self.device)
135
+ #######################
136
+
137
+ print(f'Uc shape: {Uc.shape}, Us shape: {Us.shape}')
138
+
139
+ with torch.no_grad():
140
+ zeros = torch.zeros_like(Us, device=self.device)
141
+ Us = torch.maximum(Us, zeros)
142
+
143
+ # use a fixed batch (i.e. descend the full gradient)
144
+ if not stochastic:
145
+ Z = torch.zeros((batch_size * n_iters, self.c, self.s, self.s), device=self.device)
146
+
147
+ # note: perform in loops to have a larger effective batch size
148
+ print(f'Starting loops, total Z shape: {Z.shape}...')
149
+ for i in range(n_iters):
150
+ np.random.seed(i)
151
+ torch.manual_seed(i)
152
+ noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
153
+ z, _ = self.sample(noise, layer=self.layer, partial=True)
154
+
155
+ Z[(batch_size * i):(batch_size * (i + 1))] = z
156
+
157
+ for t in range(its):
158
+ np.random.seed(t)
159
+ torch.manual_seed(t)
160
+
161
+ # resample the batch, if stochastic
162
+ if stochastic:
163
+ noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
164
+ Z, _ = self.sample(noise, layer=self.layer, partial=True)
165
+
166
+ if verbose:
167
+ # reconstruct (for visualisation)
168
+ coords = tl.tenalg.multi_mode_dot(Z.view(-1, self.c, self.s**2).float(), [Uc.T, Us.T], transpose=False, modes=[1, 2])
169
+ Z_rec = tl.tenalg.multi_mode_dot(coords, [Uc, Us], transpose=False, modes=[1, 2])
170
+
171
+ self.rec_loss = torch.mean(torch.norm(Z.view(-1, self.c, self.s**2).float() - Z_rec, p='fro', dim=[1, 2]) ** 2)
172
+
173
+ # Update S
174
+ z = Z.view(-1, self.c, self.s**2).float()
175
+ Us_g = -4 * (torch.transpose(z,1,2)@Uc@Uc.T@z@Us) + \
176
+ 2 * (Us@Us.T@torch.transpose(z,1,2)@Uc@Uc.T@Uc@Uc.T@z@Us + torch.transpose(z,1,2)@Uc@Uc.T@Uc@Uc.T@z@Us@Us.T@Us)
177
+ Us_g = torch.sum(Us_g, 0)
178
+
179
+ Us = Us - lr * Us_g
180
+ # --- projection step ---
181
+ Us = torch.maximum(Us, zeros)
182
+
183
+ # Update C
184
+ Uc_g = -4 * (z@Us@Us.T@torch.transpose(z,1,2)@Uc) + \
185
+ 2 * (Uc@Uc.T@z@Us@Us.T@Us@Us.T@torch.transpose(z,1,2)@Uc + z@Us@Us.T@Us@Us.T@torch.transpose(z,1,2)@Uc@Uc.T@Uc)
186
+ Uc_g = torch.sum(Uc_g, 0)
187
+ Uc = Uc - lr * Uc_g
188
+
189
+ self.Us = Us
190
+ self.Uc = Uc
191
+
192
+ if t % log_modulo == 0 and verbose:
193
+ print(f'ITERATION: {t}')
194
+ z, x = self.sample(noise, layer=self.layer, partial=False)
195
+
196
+ # here we display the learnt parts factors and also overlay them over the images to visualise.
197
+ plot_masks(Us.T, r=min(ranks[-1], 32), s=self.s)
198
+ plt.show()
199
+ plot_colours(x, Us.T, r=ranks[-1], s=self.s, seed=-1)
200
+ plt.show()
201
+
202
+ def decompose_autograd(self, ranks=[512, 8], lr=1e-8, batch_size=1, its=10000, log_modulo=1000, verbose=True, hosvd_init=True):
203
+ """
204
+ Performs the same decomposition as in the paper, but using autograd with the Adam optimizer (and projected gradient descent).
205
+
206
+ Parameters
207
+ ----------
208
+ ranks : list
209
+ List of two integers [R_C, R_S] specifying the ranks, i.e. the number of appearance and parts factors respectively.
210
+ lr : float
211
+ Learning rate for the projected gradient descent.
212
+ batch_size : int
213
+ Number of samples in each batch.
214
+ its : int
215
+ Total number of iterations.
216
+ log_modulo : int
217
+ Parameter used to control how often "training" information is displayed.
218
+ hosvd_init : bool
219
+ Initialise appearance factors from HOSVD? (else from random normal).
220
+ verbose : bool
221
+ Prints extra information.
222
+ """
223
+ self.ranks = ranks
224
+ np.random.seed(0)
225
+ torch.manual_seed(0)
226
+
227
+ #######################
228
+ # init from HOSVD, else random normal
229
+ Uc = torch.nn.Parameter(self.Uc_init[:, :ranks[0]].detach().clone().to(self.device), requires_grad=True) \
230
+ if hosvd_init else torch.nn.Parameter(torch.randn(self.Uc_init.shape[0], ranks[0]).detach().clone().to(self.device) * 0.01)
231
+ Us = torch.nn.Parameter(torch.Tensor(np.random.uniform(0, 0.01, size=[self.s**2, ranks[1]])).to(self.device), requires_grad=True)
232
+ #######################
233
+ optimizerS = torch.optim.Adam([Us], lr=lr)
234
+ optimizerC = torch.optim.Adam([Uc], lr=lr)
235
+
236
+ print(f'Uc shape: {Uc.shape}, Us shape: {Us.shape}')
237
+
238
+ zeros = torch.zeros_like(Us, device=self.device)
239
+ for t in range(its):
240
+ np.random.seed(t)
241
+ torch.manual_seed(t)
242
+
243
+ noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
244
+ Z, _ = self.sample(noise, layer=self.layer, partial=True)
245
+
246
+ # Update S
247
+ # reconstruct
248
+ coords = tl.tenalg.multi_mode_dot(Z.view(-1, self.c, self.s**2).float(), [Uc.T, Us.T], transpose=False, modes=[1, 2])
249
+ Z_rec = tl.tenalg.multi_mode_dot(coords, [Uc, Us], transpose=False, modes=[1, 2])
250
+
251
+ rec_loss = torch.mean(torch.norm(Z.view(-1, self.c, self.s**2).float() - Z_rec, p='fro', dim=[1, 2]) ** 2)
252
+ rec_loss.backward(retain_graph=True)
253
+
254
+ optimizerS.step()
255
+ # --- projection step ---
256
+ Us.data = torch.maximum(Us.data, zeros)
257
+ optimizerS.zero_grad()
258
+ optimizerC.zero_grad()
259
+
260
+ # Update C
261
+ # reconstruct with updated Us
262
+ coords = tl.tenalg.multi_mode_dot(Z.view(-1, self.c, self.s**2).float(), [Uc.T, Us.T], transpose=False, modes=[1, 2])
263
+ Z_rec = tl.tenalg.multi_mode_dot(coords, [Uc, Us], transpose=False, modes=[1, 2])
264
+
265
+ rec_loss = torch.mean(torch.norm(Z.view(-1, self.c, self.s**2).float() - Z_rec, p='fro', dim=[1, 2]) ** 2)
266
+ rec_loss.backward()
267
+ optimizerC.step()
268
+ optimizerS.zero_grad()
269
+ optimizerC.zero_grad()
270
+
271
+ self.Us = Us
272
+ self.Uc = Uc
273
+
274
+ with torch.no_grad():
275
+ if t % log_modulo == 0 and verbose:
276
+ print(f'Iteration {t} -- rec {rec_loss}')
277
+
278
+ noise = torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)
279
+ Z, x = self.sample(noise, layer=self.layer, partial=False)
280
+
281
+ plot_masks(Us.T, r=min(ranks[-1], 32), s=self.s)
282
+ plt.show()
283
+ plot_colours(x, Us.T, r=ranks[-1], s=self.s, seed=-1)
284
+ plt.show()
285
+
286
+ def refine(self, Z, image, lr=1e-8, its=1000, log_modulo=250, verbose=True):
287
+ """
288
+ Performs the "refinement" step described in the paper, for a given sample Z.
289
+
290
+ Parameters
291
+ ----------
292
+ Z : torch.Tensor
293
+ Intermediate activations for target refinement.
294
+ image : np.array
295
+ Corresponding image for Z (purely for visualisation purposes).
296
+ lr : float
297
+ Learning rate for the projected gradient descent.
298
+ its : int
299
+ Total number of iterations.
300
+ log_modulo : int
301
+ Parameter used to control how often "training" information is displayed.
302
+ verbose : bool
303
+ Prints extra information.
304
+
305
+ Returns
306
+ -------
307
+ UsR : torch.Tensor
308
+ The refined factors \tilde{P}_i.
309
+ """
310
+ np.random.seed(0)
311
+ torch.manual_seed(0)
312
+
313
+ #######################
314
+ # init from global spatial factors
315
+ UsR = self.Us.clone()
316
+ Uc = self.Uc
317
+ #######################
318
+
319
+ zeros = torch.zeros_like(self.Us, device=self.device)
320
+ for t in range(its):
321
+ with torch.no_grad():
322
+ z = Z.view(-1, self.c, self.s**2).float()
323
+ # descend refinement term's gradient
324
+ UsR_g = -4 * (torch.transpose(z,1,2)@Uc@Uc.T@z@UsR) + \
325
+ 2 * (UsR@UsR.T@torch.transpose(z,1,2)@Uc@Uc.T@Uc@Uc.T@z@UsR + torch.transpose(z,1,2)@Uc@Uc.T@Uc@Uc.T@z@UsR@UsR.T@UsR)
326
+ UsR_g = torch.sum(UsR_g, 0)
327
+
328
+ # Update S
329
+ UsR = UsR - lr * UsR_g
330
+ # PGD step
331
+ UsR = torch.maximum(UsR, zeros)
332
+
333
+ if ((t + 1) % log_modulo == 0 and verbose):
334
+ print(f'iteration {t}')
335
+
336
+ plot_masks(UsR.T, s=self.s, r=min(self.ranks[-1], 16))
337
+ plt.show()
338
+ plot_colours(image, UsR.T, s=self.s, r=self.ranks[-1], seed=-1, alpha=0.9)
339
+ plt.show()
340
+
341
+ return UsR
342
+
343
+ def edit_at_layer(self, part, appearance, lam, t, Uc, Us, noise=None, b_idx=0):
344
+ """
345
+ Performs local image editing at layer `self.layer`, applying the chosen appearance(s) at the chosen part(s) for a given sample.
346
+
347
+ Parameters
348
+ ----------
349
+ part : list
350
+ List of ints containing the part(s) (column of Us) at which to edit.
351
+ appearance : list
352
+ List of ints containing the appearance (column of Uc) to apply at the corresponding part(s).
353
+ lam : list
354
+ List of ints containing the magnitude for each edit.
355
+ t : int
356
+ Random seed of the sample to edit.
357
+ Uc : np.array
358
+ Learnt appearance factors
359
+ Us : np.array
360
+ Learnt parts factors
361
+ noise : np.array
362
+ If specified, the target latent code itself to edit (i.e. instead of providing a random seed).
363
+ b_idx : int
364
+ Index of biggan categories to use.
365
+
366
+ Returns
367
+ -------
368
+ Z : torch.Tensor
369
+ The intermediate activations at layer `self.layer`.
370
+ image : np.array
371
+ The original image for sample `t` or from latent code `noise`.
372
+ image2 : np.array
373
+ The edited image.
374
+ part : np.array
375
+ The part used to edit.
376
+ """
377
+ with torch.no_grad():
378
+ if noise is None:
379
+ np.random.seed(t)
380
+ torch.manual_seed(t)
381
+ noise = torch.Tensor(np.random.randn(1, self.generator.z_space_dim)).to(self.device)
382
+ else:
383
+ np.random.seed(0)
384
+ torch.manual_seed(0)
385
+
386
+ direc = 0
387
+ for i in range(len(appearance)):
388
+ a = Uc[:, appearance[i]]
389
+ p = torch.sum(Us[:, part[i]], dim=-1).reshape([self.s, self.s])
390
+ p = mapRange(p, torch.min(p), torch.max(p), 0.0, 1.0)
391
+
392
+ # here, we basically form a rank-1 "tensor", to add to the target sample's activations.
393
+ # intuitively, the non-zero spatial positions of the part are filled with the appearance vector.
394
+ direc += lam[i] * tl.tenalg.outer([a, p])
395
+
396
+ if self.gan_type in ['stylegan', 'stylegan2']:
397
+ noise = self.generator.mapping(noise)['w']
398
+ noise_trunc = self.generator.truncation(noise, trunc_psi=self.trunc_psi, trunc_layers=self.trunc_layers)
399
+
400
+ Z = self.generator.synthesis(noise_trunc, start=self.start, stop=self.layer)['x']
401
+
402
+ x = self.generator.synthesis(noise_trunc, x=Z, start=self.layer)['image']
403
+ x_prime = self.generator.synthesis(noise_trunc, x=Z + direc, start=self.layer)['image']
404
+ elif 'pggan' in self.gan_type:
405
+ Z = self.generator(noise, start=self.start, stop=self.layer)['x']
406
+
407
+ x = self.generator(Z, start=self.layer)['image']
408
+ x_prime = self.generator(Z + direc, start=self.layer)['image']
409
+ elif 'biggan' in self.gan_type:
410
+ print(f'Choosing a {self.biggan_classes[b_idx]}')
411
+ class_vector = torch.tensor(one_hot_from_names([self.biggan_classes[b_idx]]), device=self.device)
412
+ noise_vector = torch.tensor(truncated_noise_sample(truncation=self.trunc_psi, batch_size=1, seed=t), device=self.device)
413
+
414
+ result = self.generator(noise_vector, class_vector, self.trunc_psi, stop=self.layer)
415
+ Z, cond_vector = result['z'], result['cond_vector']
416
+ x = self.generator(Z, class_vector, self.trunc_psi, cond_vector=cond_vector, start=self.layer)['z']
417
+ x_prime = self.generator(Z + direc, class_vector, self.trunc_psi, cond_vector=cond_vector, start=self.layer)['z']
418
+ elif 'stylegan3' in self.gan_type:
419
+ label = torch.zeros([1, 0], device=self.device)
420
+ Z = self.generator(noise, label, stop=self.layer, truncation_psi=self.trunc_psi, noise_mode='const')
421
+
422
+ x = self.generator(noise, label, x=Z, start=self.layer, stop=None, truncation_psi=self.trunc_psi, noise_mode='const')
423
+ x_prime = self.generator(noise, label, x=Z + direc, start=self.layer, stop=None, truncation_psi=self.trunc_psi, noise_mode='const')
424
+
425
+ image = np.array(Image.fromarray(postprocess(x.cpu().numpy())[0]).resize((256, 256)))
426
+ image2 = np.array(Image.fromarray(postprocess(x_prime.cpu().numpy())[0]).resize((256, 256)))
427
+
428
+ part = np.array(Image.fromarray(p.detach().cpu().numpy() * 255).convert('RGB').resize((256, 256), Image.NEAREST))
429
+ return Z, image, image2, part
430
+
431
+ def sample(self, noise, layer=5, partial=False, trunc_psi=1.0, trunc_layers=18, verbose=False):
432
+ """
433
+ Samples intermediate feature maps and the resulting image from the desired generator.
434
+
435
+ Parameters
436
+ ----------
437
+ noise : np.array
438
+ (batch_size, z_dim)-dim random standard gaussian noise.
439
+ layer : int
440
+ Intermediate layer at which to return intermediate features.
441
+ partial : bool
442
+ If True, return only the intermediate activations at layer `layer` (no image); otherwise also perform the full forward pass and return the image.
443
+ trunc_psi : float
444
+ Truncation value in [0, 1].
445
+ trunc_layers : int
446
+ Number of layers at which to apply truncation.
447
+ biggan_classes : list
448
+ List of strings specifying imagenet classes of interest (e.g. ['alp', 'breakwater']).
449
+ verbose : bool
450
+ Print out additional information?
451
+
452
+ Returns
453
+ -------
454
+ Z : torch.Tensor
455
+ The intermediate activations of shape [C, H, W].
456
+ image : np.array
457
+ Output RGB image.
458
+ """
459
+ with torch.no_grad():
460
+ if self.gan_type in ['stylegan', 'stylegan2']:
461
+ noise = self.generator.mapping(noise)['w']
462
+ noise_trunc = self.generator.truncation(noise, trunc_psi=trunc_psi, trunc_layers=trunc_layers)
463
+ Z = self.generator.synthesis(noise_trunc, start=self.start, stop=layer)['x']
464
+ if not partial:
465
+ x = self.generator.synthesis(noise_trunc, x=Z, start=layer)['image']
466
+ elif 'pggan' in self.gan_type:
467
+ Z = self.generator(noise, start=self.start, stop=layer)['x']
468
+ if not partial:
469
+ x = self.generator(Z, start=layer)['image']
470
+ elif 'biggan' in self.gan_type:
471
+ if verbose:
472
+ print(f'Using BigGAN class names: {", ".join(self.biggan_classes)}')
473
+
474
+ class_vector = torch.tensor(one_hot_from_names(list(np.random.choice(self.biggan_classes, noise.shape[0])), batch_size=noise.shape[0]), device=self.device)
475
+ noise_vector = torch.tensor(truncated_noise_sample(truncation=self.trunc_psi, batch_size=noise.shape[0]), device=self.device)
476
+
477
+ result = self.generator(noise_vector, class_vector, self.trunc_psi, stop=layer)
478
+ Z = result['z']
479
+ cond_vector = result['cond_vector']
480
+
481
+ if not partial:
482
+ x = self.generator(Z, class_vector, self.trunc_psi, cond_vector=cond_vector, start=layer)['z']
483
+ elif 'stylegan3' in self.gan_type:
484
+ label = torch.zeros([noise.shape[0], 0], device=self.device)
485
+ if hasattr(self.generator.synthesis, 'input'):
486
+ m = np.linalg.inv(make_transform((0,0), 0))
487
+ self.generator.synthesis.input.transform.copy_(torch.from_numpy(m))
488
+
489
+ Z = self.generator(noise, label, x=None, start=0, stop=layer, truncation_psi=trunc_psi, noise_mode='const')
490
+ if not partial:
491
+ x = self.generator(noise, label, x=Z, start=layer, stop=None, truncation_psi=trunc_psi, noise_mode='const')
492
+
493
+ if verbose:
494
+ print(f'-- Partial Z shape at layer {layer}: {Z.shape}')
495
+
496
+ if partial:
497
+ return Z, None
498
+ else:
499
+ image = postprocess(x.detach().cpu().numpy())
500
+ image = np.array(Image.fromarray(image[0]).resize((256, 256)))
501
+ return Z, image
502
+
503
+ def save(self):
504
+ Uc_path = f'./checkpoints/Uc-name_{self.model_name}-layer_{self.layer}-rank_{self.ranks[0]}.npy'
505
+ Us_path = f'./checkpoints/Us-name_{self.model_name}-layer_{self.layer}-rank_{self.ranks[1]}.npy'
506
+
507
+ np.save(Us_path, self.Us.detach().cpu().numpy())
508
+ np.save(Uc_path, self.Uc.detach().cpu().numpy())
509
+
510
+ print(f'Saved factors to {Us_path}, {Uc_path}')
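+
+ # Usage sketch (added for illustration; not part of the original file).
+ # A minimal end-to-end run of the methods above: decompose a layer, save the
+ # factors, then apply a local edit. Hyperparameter values here are illustrative
+ # only, and the pre-trained generator weights are required (see readme.md).
+ if __name__ == '__main__':
+     M = Model('stylegan2_ffhq1024', layer=5, trunc_psi=1.0, device='cuda')
+     M.HOSVD(batch_size=10, n_iters=10)   # initialise the appearance basis
+     M.decompose(ranks=[512, 8], lr=1e-8, its=1000, log_modulo=500)
+     M.save()                             # writes Uc/Us to ./checkpoints/
+     # e.g. apply appearance 6 at part 7 with magnitude 30 (cf. 'big_eyes' in annotated_directions.py)
+     Z, image, image2, part_img = M.edit_at_layer([[7]], [6], [30], t=64, Uc=M.Uc, Us=M.Us)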
readme.md ADDED
@@ -0,0 +1,82 @@
1
+ # PandA: Unsupervised Learning of Parts and Appearances in the Feature Maps of GANs
2
+
3
+ ## [ [paper](https://openreview.net/pdf?id=iUdSB2kK9GY) | [project page](http://eecs.qmul.ac.uk/~jo001/PandA/) | [video](https://www.youtube.com/watch?v=1KY055goKP0) | [edit zoo](https://colab.research.google.com/github/james-oldfield/PandA/blob/main/ffhq-edit-zoo.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/james-oldfield/PandA/blob/main/demo.ipynb) ]
4
+
5
+ ![main.jpg](./images/main.jpg)
6
+
7
+ > **PandA: Unsupervised Learning of Parts and Appearances in the Feature Maps of GANs**<br>
8
+ > James Oldfield, Christos Tzelepis, Yannis Panagakis, Mihalis A. Nicolaou, and Ioannis Patras<br>
9
+ > *International Conference on Learning Representations (ICLR)*, 2023 <br>
10
+ > https://arxiv.org/abs/2206.00048 <br>
11
+ >
12
+ > **Abstract**: Recent advances in the understanding of Generative Adversarial Networks (GANs) have led to remarkable progress in visual editing and synthesis tasks, capitalizing on the rich semantics that are embedded in the latent spaces of pre-trained GANs. However, existing methods are often tailored to specific GAN architectures and are limited to either discovering global semantic directions that do not facilitate localized control, or require some form of supervision through manually provided regions or segmentation masks. In this light, we present an architecture-agnostic approach that jointly discovers factors representing spatial parts and their appearances in an entirely unsupervised fashion. These factors are obtained by applying a semi-nonnegative tensor factorization on the feature maps, which in turn enables context-aware local image editing with pixel-level control. In addition, we show that the discovered appearance factors correspond to saliency maps that localize concepts of interest, without using any labels. Experiments on a wide range of GAN architectures and datasets show that, in comparison to the state of the art, our method is far more efficient in terms of training time and, most importantly, provides much more accurate localized control.
13
+
14
+ ![cat-gif](./images/cat-eye-control.gif)
15
+ > An example of using our learnt appearances and semantic parts for local image editing.
16
+
17
+ ## Experiments
18
+
19
+ We provide a number of notebooks to reproduce the experiments in the paper and to explore the model. Please see the following notebooks:
20
+
21
+ # [`./demo.ipynb`](./demo.ipynb)
22
+
23
+ This notebook contains the code to learn the parts and appearance factors at a target layer in a target GAN. Contains code for local image editing using the learnt parts, and provides code for refining the parts factors.
24
+
25
+ | Local image editing (at the learnt semantic parts) | |
26
+ | :-- | :-- |
27
+ | ![image](./images/l8-t645-Rs16-Rc512-rTrue-lam[100]-p[6]-start-end.gif) | ![image](./images/l8-t16-Rs8-Rc512-rTrue-lam-150-p2-start-end.gif)
28
+
29
+ # [`./localize-concepts.ipynb`](./localize-concepts.ipynb)
30
+
31
+ Provides code to localize/visualize concepts of interest for a model/dataset of interest (setup for the "background" concept in `stylegan2_afhqdog512` as an example).
32
+
33
+ | Localizing the learnt "background" concept vector |
34
+ | :-- |
35
+ | ![image](./images/mask-bg-42.gif) ![image](./images/mask-bg-83.gif) ![image](./images/mask-bg-29.gif) |
36
+
37
+ # [`./ffhq-edit-zoo.ipynb`](./ffhq-edit-zoo.ipynb)
38
+
39
+ Quickly produce edits with annotated directions with pre-trained factors on FFHQ StyleGAN2.
40
+
41
+ | Local image editing: "Big eyes" |
42
+ | :-- |
43
+ | ![image](./images/qualitative.png) |
44
+
45
+ ## Setup
46
+
47
+ Should you wish to run the notebooks, please consult this section below:
48
+
49
+ ### Install
50
+ First, please install the dependencies with `pip install -r requirements.txt`, or alternatively with conda using `conda env create -f environment.yml`
51
+
52
+ ### Pre-trained models
53
+ Should you wish to run the notebooks, please first download the pre-trained models with:
54
+
55
+ ```bash
56
+ wget -r -np -nH --cut-dirs=2 -R *index* http://eecs.qmul.ac.uk/~jo001/PandA-pretrained-models/
57
+ ```
58
+
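+ For reference, the pre-trained parts/appearance factors shipped in `./checkpoints/` are plain NumPy arrays; a minimal (illustrative) snippet for loading them into the `Model` class, as done in `app.py`:
+
+ ```python
+ import numpy as np
+ import torch
+ from model import Model
+
+ M = Model('stylegan2_ffhq1024', layer=5, device='cpu')
+ M.Us = torch.Tensor(np.load('./checkpoints/Us-name_stylegan2_ffhq1024-layer_5-rank_8.npy'))
+ M.Uc = torch.Tensor(np.load('./checkpoints/Uc-name_stylegan2_ffhq1024-layer_5-rank_512.npy'))
+ ```
+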
59
+ # citation
60
+
61
+ If you find our work useful, please consider citing our paper:
62
+
63
+ ```bibtex
64
+ @inproceedings{oldfield2023panda,
65
+ title={PandA: Unsupervised Learning of Parts and Appearances in the Feature Maps of GANs},
66
+ author={James Oldfield and Christos Tzelepis and Yannis Panagakis and Mihalis A. Nicolaou and Ioannis Patras},
67
+ booktitle={Int. Conf. Learn. Represent.},
68
+ year={2023}
69
+ }
70
+ ```
71
+
72
+ # contact
73
+
74
+ **Please feel free to get in touch at**: `j.a.oldfield@qmul.ac.uk`
75
+
76
+ ---
77
+
78
+ ## credits
79
+
80
+ - `./networks/genforce/` contains mostly code directly from [https://github.com/genforce/genforce](https://github.com/genforce/genforce).
81
+ - `./networks/biggan/` contains mostly code directly from [https://github.com/huggingface/pytorch-pretrained-BigGAN](https://github.com/huggingface/pytorch-pretrained-BigGAN).
82
+ - `./networks/stylegan3/` contains mostly code directly from [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ boto3==1.21.8
2
+ botocore==1.24.8
3
+ bs4
4
+ imageio==2.14.1
5
+ jupyterlab
6
+ matplotlib==3.5.1
7
+ nltk==3.7
8
+ numpy
9
+ opencv-python==4.2.0.32
10
+ Pillow==9.0.1
11
+ requests==2.27.0
12
+ tensorly==0.7.0
13
+ torch==1.10.2
14
+ torchvision==0.11.3
utils.py ADDED
@@ -0,0 +1,110 @@
1
+ from matplotlib import pyplot as plt
2
+ from matplotlib import gridspec
3
+ import matplotlib.patches as mpatches
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+
9
+ def get_cols():
10
+ # list of perceptually distinct colours (for spatial factor plots)
11
+ return np.array([[255,0,0], [255,255,0], [0,234,255], [170,0,255], [255,127,0], [191,255,0], [0,149,255], [255,0,170], [255,212,0], [106,255,0], [0,64,255], [237,185,185], [185,215,237], [231,233,185], [220,185,237], [185,237,224], [143,35,35], [35,98,143], [143,106,35], [107,35,143], [79,143,35], [0,0,0], [115,115,115], [204,204,204]])
12
+
13
+
14
+ def mapRange(value, inMin, inMax, outMin, outMax):
15
+ return outMin + (((value - inMin) / (inMax - inMin)) * (outMax - outMin))
16
+
17
+
18
+ def plot_masks(Us, r, s, rs=256, save_path=None, title_factors=True):
19
+ """
20
+ Plots the parts factors with matplotlib for visualization
21
+
22
+ Parameters
23
+ ----------
24
+ Us : np.array
25
+ Learnt parts factor matrix.
26
+ r : int
27
+ Number of factors to show.
28
+ s : int
29
+ Dimensions of each part (h*w).
30
+ rs : int
31
+ Target size to downsize images to.
32
+ save_path : bool
33
+ Save figure?
34
+ title_factors : bool
35
+ Print matplotlib title on each part?
36
+ """
37
+
38
+ fig = plt.figure(constrained_layout=True, figsize=(20, 3))
39
+ spec = gridspec.GridSpec(ncols=r + 1, nrows=1, figure=fig)
40
+
41
+ for i in range(0, r):
42
+ fig.add_subplot(spec[i])
43
+
44
+ if title_factors:
45
+ plt.title(f'Part {i}')
46
+
47
+ part = Us[i].reshape([s, s])
48
+ part = mapRange(part, torch.min(part), torch.max(part), 0.0, 1.0) * 255
49
+ part = part.detach().cpu().numpy()
50
+ part = np.array(Image.fromarray(np.uint8(part)).convert('RGBA').resize((rs, rs), Image.NEAREST)) / 255
51
+
52
+ plt.axis('off')
53
+ plt.imshow(part, vmin=1, vmax=1, cmap='gray', alpha=1.00)
54
+
55
+ if save_path is not None:
56
+ plt.savefig(save_path)
57
+
58
+
59
+ def plot_colours(image, Us, r, s, rs=128, save_path=None, alpha=1.0, seed=-1, legend=True):
60
+ """
61
+ Plots the parts factors over an image with matplotlib for visualization
62
+
63
+ Parameters
64
+ ----------
65
+ image : np.array
66
+ Image to visualize.
67
+ Us : np.array
68
+ Learnt parts factor matrix.
69
+ r : int
70
+ Number of factors to show.
71
+ s : int
72
+ Dimensions of each part (h*w).
73
+ rs : int
74
+ Target size to downsize images to.
75
+ alpha : float
76
+ Alpha value for the masks.
77
+ seed : int
78
+ Random seed when generating the colour palette (use -1 to use the provided "perceptually distinct" colour palette, but note this has a maximum of 30 colours or so).
79
+ legend : bool
80
+ Plot the legend, detailing the colour-coded parts key?
81
+ """
82
+
83
+ img = Image.fromarray(image).resize((rs, rs)).convert('RGBA')
84
+
85
+ # Use the perceptually distinct colour list, or a seeded random palette (e.g. if you have too many factors)
86
+ cols = get_cols()
87
+ if seed >= 0:
88
+ np.random.seed(seed)
89
+ cols = np.random.randint(0, 255, [r, 3])
90
+
91
+ plt.imshow(img, alpha=1.0)
92
+ plt.axis('off')
93
+
94
+ patches = []
95
+ for i in range(0, r):
96
+ mask = Us[i].detach().cpu().numpy().reshape([s, s])
97
+ mask = mapRange(mask, np.min(mask), np.max(mask), 0, 255)
98
+ mask = np.uint8(mask)
99
+ mask = np.array(Image.fromarray(mask).convert('L').resize((rs, rs)))
100
+ mask = (mask[:, :, None] / 255.) * np.array(np.concatenate([cols[i] / 255, [1]]))
101
+
102
+ patches += [mpatches.Patch(color=cols[i] / 255, label=f'Part {i}')]
103
+
104
+ plt.imshow(mask, vmin=0, vmax=1, alpha=alpha)
105
+
106
+ if legend:
107
+ plt.legend(title='Spatial factors', handles=patches, bbox_to_anchor=(1.01, 1.01), loc="upper left")
108
+
109
+ if save_path is not None:
110
+ plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)
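+
+
+ # Usage sketch (added for illustration; not part of the original file): these
+ # helpers are called from model.py with the learnt parts factors, e.g. for a
+ # factor matrix `Us` of shape [s*s, R_S] and an RGB image `image`:
+ #     plot_masks(Us.T, r=8, s=s)                    # one greyscale subplot per part
+ #     plot_colours(image, Us.T, r=8, s=s, seed=-1)  # colour-coded overlay with legend
+ #     plt.show()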