# PandA / model.py
# Page residue from the hosting site's file viewer (kept for provenance):
#   james-oldfield's picture
#   Upload 18 files
#   ceb80dd
#   raw
#   history blame contribute delete
#   No virus
#   23 kB
from networks.load_generator import load_generator
from networks.genforce.utils.visualizer import postprocess_image as postprocess
from networks.biggan import one_hot_from_names, truncated_noise_sample
from networks.stylegan3.load_stylegan3 import make_transform
from matplotlib import pyplot as plt
from utils import plot_masks, plot_colours, mapRange
import torch
import numpy as np
from PIL import Image
import tensorly as tl
tl.set_backend('pytorch')
class Model:
    """
    Wrapper around a pre-trained generator that learns "parts" (spatial) and
    "appearance" (channel) factors of its intermediate activations, and uses
    them for local image editing.

    Typical order of use: construct the model, call `HOSVD()` to initialise the
    appearance basis, call `decompose()` (or `decompose_autograd()`) to learn
    the factors, optionally `refine()` them for one sample, then
    `edit_at_layer()` and/or `save()`.
    """

    def __init__(self, model_name, t=0, layer=5, trunc_psi=1.0, trunc_layers=18, device='cuda', biggan_classes=None):
        """
        Instantiate the model for decomposition and/or local image editing.

        Parameters
        ----------
        model_name : string
            Name of architecture and dataset--one of the items in ./networks/genforce/models/model_zoo.py.
            The text before the first underscore is used as the GAN type.
        t : int
            Random seed for the generator (to generate a sample image).
            Pass None to leave the global RNG state untouched.
        layer : int
            Intermediate layer at which to perform the decomposition.
        trunc_psi : float
            Truncation value in [0, 1].
        trunc_layers : int
            Number of layers at which to apply truncation.
        device : string
            Device to store the tensors on.
        biggan_classes : list
            List of strings specifying imagenet classes of interest (e.g. ['alp', 'breakwater']).
            Defaults to ['fox'].
        """
        self.gan_type = model_name.split('_')[0]
        self.model_name = model_name
        self.randomize_noise = False
        self.device = device
        # a fresh list per instance, so instances never share a mutable default
        self.biggan_classes = ['fox'] if biggan_classes is None else biggan_classes
        self.layer = layer  # layer to decompose
        self.start = 0 if 'stylegan2' in self.gan_type else 2
        self.trunc_psi = trunc_psi
        self.trunc_layers = trunc_layers

        self.generator = load_generator(model_name, device)

        # honour the documented seed so that the sample image is reproducible
        if t is not None:
            self._reseed(t)
        noise = self._draw_noise(1)
        z, image = self.sample(noise, layer=layer, trunc_psi=trunc_psi, trunc_layers=trunc_layers, verbose=True)

        # channel count and spatial size of the activations at `layer`
        self.c = z.shape[1]
        self.s = z.shape[2]
        self.image = image

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _reseed(seed):
        """Seed both NumPy's and PyTorch's global random number generators."""
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _draw_noise(self, batch_size):
        """Draw a (batch_size, z_dim) standard-normal latent batch on `self.device`."""
        return torch.Tensor(np.random.randn(batch_size, self.generator.z_space_dim)).to(self.device)

    def _flatten(self, Z):
        """View activations as float (batch, channels, s*s), i.e. with the spatial modes vectorised."""
        return Z.view(-1, self.c, self.s ** 2).float()

    def _rec_loss(self, Z, Uc, Us):
        """Mean squared Frobenius error of reconstructing `Z` from its projection onto `Uc` and `Us`."""
        z = self._flatten(Z)
        coords = tl.tenalg.multi_mode_dot(z, [Uc.T, Us.T], transpose=False, modes=[1, 2])
        Z_rec = tl.tenalg.multi_mode_dot(coords, [Uc, Us], transpose=False, modes=[1, 2])
        return torch.mean(torch.norm(z - Z_rec, p='fro', dim=[1, 2]) ** 2)

    def _init_appearance(self, n_appearances, hosvd_init):
        """Initial appearance factors: leading HOSVD vectors, else small random normal values."""
        if hosvd_init:
            if not hasattr(self, 'Uc_init'):
                raise AttributeError('hosvd_init=True requires calling HOSVD() first (Uc_init is not set).')
            return self.Uc_init[:, :n_appearances].detach().clone().to(self.device)
        # the random initialisation only needs the channel count, not the HOSVD result
        return torch.randn(self.c, n_appearances).detach().clone().to(self.device) * 0.01

    def _init_parts(self, n_parts):
        """Initial parts factors: small non-negative uniform values of shape (s*s, n_parts)."""
        return torch.Tensor(np.random.uniform(0, 0.01, size=[self.s ** 2, n_parts])).to(self.device)

    # --------------------------------------------------------------- public API

    def HOSVD(self, batch_size=10, n_iters=100):
        """
        Initialises the appearance basis A. In particular, computes the left-singular vectors of the channel mode's scatter matrix.
        The result is stored in `self.Uc_init`.

        Note: total samples used is batch_size * n_iters
        Note: activations are sampled with `sample`'s default truncation arguments.

        Parameters
        ----------
        batch_size : int
            Number of activations to sample in a single go.
        n_iters : int
            Number of times to sample `batch_size`-many activations.

        Raises
        ------
        ValueError
            If batch_size * n_iters is not positive.
        """
        n_samples = batch_size * n_iters
        if n_samples <= 0:
            raise ValueError('HOSVD needs batch_size * n_iters > 0 samples.')

        self._reseed(0)

        with torch.no_grad():
            # note: perform in loops to have a larger effective batch size
            print('Starting loops...')
            scat = 0
            for i in range(n_iters):
                self._reseed(i)
                noise = self._draw_noise(batch_size)
                z, _ = self.sample(noise, layer=self.layer, partial=True)
                z = z.to(device=self.device, dtype=torch.float32).reshape(-1, self.c, self.s ** 2)

                # accumulate the scatter matrix batch by batch (same sample order as
                # a single big buffer would give) so all activations never sit in memory at once
                for x in z:
                    # mode-3 unfolding in the paper, but in PyTorch channel mode is first.
                    m_unfold = tl.unfold(x, 0)
                    scat += m_unfold @ m_unfold.T

            print(f'Generated {n_samples} gan samples...')

            self.Uc_init, _, _ = np.linalg.svd((scat / n_samples).cpu().numpy())
            self.Uc_init = torch.Tensor(self.Uc_init).to(self.device)

        print('... HOSVD done')

    def decompose(self, ranks=None, lr=1e-8, batch_size=1, its=10000, log_modulo=1000, hosvd_init=True, stochastic=True, n_iters=1, verbose=True):
        """
        Performs the decomposition in the paper. In particular, Algorithm 1.,
        either with a non-fixed batch of samples (stochastic=True), or descends the full gradients.
        The learnt factors are stored in `self.Uc` and `self.Us`.

        Parameters
        ----------
        ranks : list
            List of integers specifying the R_C and R_S, the ranks--i.e. number of appearances and parts
            respectively (ranks[0] sizes Uc, ranks[1] sizes Us). Defaults to [512, 8].
        lr : float
            Learning rate the projected gradient descent.
        batch_size : int
            Number of samples in each batch.
        its : int
            Total number of iterations.
        log_modulo : int
            Parameter used to control how often "training" information is displayed.
        hosvd_init : bool
            Initialise appearance factors from HOSVD? (else from random normal).
        stochastic : bool
            Sample the batch again each iteration? Else descent full gradients
        n_iters : int
            Number of `batch_size`-many samples to take (for full gradient).
            The total activations are sampled in batches in a loop to enable it to fit in memory.
        verbose : bool
            Prints extra information (and stores the latest reconstruction loss in `self.rec_loss`).

        Raises
        ------
        AttributeError
            If hosvd_init is True but `HOSVD()` has not been called.
        """
        ranks = [512, 8] if ranks is None else ranks
        self.ranks = ranks
        self._reseed(0)

        #######################
        # init from HOSVD, else random normal
        Uc = self._init_appearance(ranks[0], hosvd_init)
        Us = self._init_parts(ranks[1])
        #######################
        print(f'Uc shape: {Uc.shape}, Us shape: {Us.shape}')

        with torch.no_grad():
            zeros = torch.zeros_like(Us, device=self.device)
            Us = torch.maximum(Us, zeros)

            # publish the initial factors so they exist even if no iteration runs
            self.Us = Us
            self.Uc = Uc

            # use a fixed batch (i.e. descend the full gradient)
            if not stochastic:
                Z = torch.zeros((batch_size * n_iters, self.c, self.s, self.s), device=self.device)

                # note: perform in loops to have a larger effective batch size
                print(f'Starting loops, total Z shape: {Z.shape}...')
                for i in range(n_iters):
                    self._reseed(i)
                    noise = self._draw_noise(batch_size)
                    z, _ = self.sample(noise, layer=self.layer, partial=True)
                    Z[(batch_size * i):(batch_size * (i + 1))] = z

            for t in range(its):
                self._reseed(t)

                # resample the batch, if stochastic
                if stochastic:
                    noise = self._draw_noise(batch_size)
                    Z, _ = self.sample(noise, layer=self.layer, partial=True)

                if verbose:
                    # reconstruct (for visualisation)
                    self.rec_loss = self._rec_loss(Z, Uc, Us)

                z = self._flatten(Z)
                zT = torch.transpose(z, 1, 2)

                # Update S
                Us_g = -4 * (zT @ Uc @ Uc.T @ z @ Us) + \
                    2 * (Us @ Us.T @ zT @ Uc @ Uc.T @ Uc @ Uc.T @ z @ Us + zT @ Uc @ Uc.T @ Uc @ Uc.T @ z @ Us @ Us.T @ Us)
                Us_g = torch.sum(Us_g, 0)
                Us = Us - lr * Us_g
                # --- projection step ---
                Us = torch.maximum(Us, zeros)

                # Update C
                Uc_g = -4 * (z @ Us @ Us.T @ zT @ Uc) + \
                    2 * (Uc @ Uc.T @ z @ Us @ Us.T @ Us @ Us.T @ zT @ Uc + z @ Us @ Us.T @ Us @ Us.T @ zT @ Uc @ Uc.T @ Uc)
                Uc_g = torch.sum(Uc_g, 0)
                Uc = Uc - lr * Uc_g

                self.Us = Us
                self.Uc = Uc

                if t % log_modulo == 0 and verbose:
                    print(f'ITERATION: {t}')

                    # `noise` is the most recently drawn latent batch
                    _, x = self.sample(noise, layer=self.layer, partial=False)

                    # here we display the learnt parts factors and also overlay them over the images to visualise.
                    plot_masks(Us.T, r=min(ranks[-1], 32), s=self.s)
                    plt.show()
                    plot_colours(x, Us.T, r=ranks[-1], s=self.s, seed=-1)
                    plt.show()

    def decompose_autograd(self, ranks=None, lr=1e-8, batch_size=1, its=10000, log_modulo=1000, verbose=True, hosvd_init=True):
        """
        Performs the same decomposition in the paper, only uses autograd with Adam optimizer (and projected gradient descent).
        The learnt factors are stored in `self.Uc` and `self.Us` (as `torch.nn.Parameter`s).

        Parameters
        ----------
        ranks : list
            List of integers specifying the R_C and R_S, the ranks--i.e. number of appearances and parts
            respectively (ranks[0] sizes Uc, ranks[1] sizes Us). Defaults to [512, 8].
        lr : float
            Learning rate the projected gradient descent.
        batch_size : int
            Number of samples in each batch.
        its : int
            Total number of iterations.
        log_modulo : int
            Parameter used to control how often "training" information is displayed.
        hosvd_init : bool
            Initialise appearance factors from HOSVD? (else from random normal).
        verbose : bool
            Prints extra information.

        Raises
        ------
        AttributeError
            If hosvd_init is True but `HOSVD()` has not been called.
        """
        ranks = [512, 8] if ranks is None else ranks
        self.ranks = ranks
        self._reseed(0)

        #######################
        # init from HOSVD, else random normal
        Uc = torch.nn.Parameter(self._init_appearance(ranks[0], hosvd_init), requires_grad=True)
        Us = torch.nn.Parameter(self._init_parts(ranks[1]), requires_grad=True)
        #######################
        optimizerS = torch.optim.Adam([Us], lr=lr)
        optimizerC = torch.optim.Adam([Uc], lr=lr)

        print(f'Uc shape: {Uc.shape}, Us shape: {Us.shape}')

        # publish the initial factors so they exist even if no iteration runs
        self.Us = Us
        self.Uc = Uc

        for t in range(its):
            self._reseed(t)

            noise = self._draw_noise(batch_size)
            Z, _ = self.sample(noise, layer=self.layer, partial=True)

            # Update S
            # the graph is rebuilt below before the next backward pass, so it need not be retained
            rec_loss = self._rec_loss(Z, Uc, Us)
            rec_loss.backward()
            optimizerS.step()

            # --- projection step ---
            with torch.no_grad():
                Us.clamp_(min=0)

            optimizerS.zero_grad()
            optimizerC.zero_grad()

            # Update C
            # reconstruct with updated Us
            rec_loss = self._rec_loss(Z, Uc, Us)
            rec_loss.backward()
            optimizerC.step()

            optimizerS.zero_grad()
            optimizerC.zero_grad()

            self.Us = Us
            self.Uc = Uc

            with torch.no_grad():
                if t % log_modulo == 0 and verbose:
                    print(f'Iteration {t} -- rec {rec_loss}')

                    noise = self._draw_noise(batch_size)
                    Z, x = self.sample(noise, layer=self.layer, partial=False)

                    plot_masks(Us.T, r=min(ranks[-1], 32), s=self.s)
                    plt.show()
                    plot_colours(x, Us.T, r=ranks[-1], s=self.s, seed=-1)
                    plt.show()

    def refine(self, Z, image, lr=1e-8, its=1000, log_modulo=250, verbose=True):
        """
        Performs the "refinement" step described in the paper, for a given sample Z.
        Starts from the global parts factors `self.Us` and descends the reconstruction
        gradient for this sample only, keeping `self.Uc` fixed. `self.Us` is not modified.

        Parameters
        ----------
        Z : torch.Tensor
            Intermediate activations for target refinement.
        image : np.array
            Corresponding image for Z (purely for visualisation purposes).
        lr : float
            Learning rate the projected gradient descent.
        its : int
            Total number of iterations.
        log_modulo : int
            Parameter used to control how often "training" information is displayed.
        verbose : bool
            Prints extra information.

        Returns
        -------
        UsR : torch.Tensor
            The refined parts factors (P-tilde in the paper), detached from autograd.
        """
        self._reseed(0)

        with torch.no_grad():
            #######################
            # init from global spatial factors
            UsR = self.Us.detach().clone()
            Uc = self.Uc
            #######################
            zeros = torch.zeros_like(UsR, device=self.device)

            # z and Uc are fixed during refinement, so the products that involve only
            # them are computed once instead of on every iteration
            z = self._flatten(Z)
            zT = torch.transpose(z, 1, 2)
            G1 = zT @ Uc @ Uc.T @ z
            G2 = zT @ Uc @ Uc.T @ Uc @ Uc.T @ z

            for t in range(its):
                # descend refinement term's gradient
                G1U = G1 @ UsR
                G2U = G2 @ UsR
                UsR_g = -4 * G1U + 2 * (UsR @ (UsR.T @ G2U) + G2U @ UsR.T @ UsR)
                UsR_g = torch.sum(UsR_g, 0)

                # Update S
                UsR = UsR - lr * UsR_g

                # PGD step
                UsR = torch.maximum(UsR, zeros)

                if ((t + 1) % log_modulo == 0 and verbose):
                    print(f'iteration {t}')

                    plot_masks(UsR.T, s=self.s, r=min(self.ranks[-1], 16))
                    plt.show()
                    plot_colours(image, UsR.T, s=self.s, r=self.ranks[-1], seed=-1, alpha=0.9)
                    plt.show()

        return UsR

    def edit_at_layer(self, part, appearance, lam, t, Uc, Us, noise=None, b_idx=0):
        """
        Edits a sample locally: adds the chosen appearance vector(s) to the activations
        at layer `self.layer`, at the spatial positions of the chosen part(s), and
        synthesises both the original and the edited image.

        Parameters
        ----------
        part : list
            One entry per edit. Each entry indexes column(s) of Us and must select a 2-D
            slice (e.g. a list of ints): the selected columns are summed into one spatial
            mask, which is rescaled to [0, 1].
        appearance : list
            List of ints containing the appearance (column of Uc) to apply at the corresponding part(s).
        lam : list
            List of numbers containing the magnitude for each edit.
        t : int
            Random seed to edit
        Uc : torch.Tensor
            Learnt appearance factors
        Us : torch.Tensor
            Learnt parts factors
        noise : torch.Tensor
            If specified, the target latent code itself to edit (i.e. instead of providing a random seed number).
            Note: the BigGAN branch always draws its latent from seed `t` and does not use `noise`.
        b_idx : int
            Index of biggan categories to use.

        Returns
        -------
        Z : torch.Tensor
            The intermediate activation at layer `self.layer`.
        image : np.array
            The original image for sample `t` or from latent code `noise` (resized to 256x256).
        image2 : np.array
            The edited image (resized to 256x256).
        part : np.array
            The mask of the last part used to edit, as a 256x256 RGB image.

        Raises
        ------
        ValueError
            If `appearance` is empty, or the GAN type is not supported.
        """
        if len(appearance) == 0:
            raise ValueError('edit_at_layer needs at least one (part, appearance, lam) edit.')

        with torch.no_grad():
            if noise is None:
                self._reseed(t)
                noise = self._draw_noise(1)
            else:
                self._reseed(0)

            direc = 0
            for i in range(len(appearance)):
                a = Uc[:, appearance[i]]
                p = torch.sum(Us[:, part[i]], dim=-1).reshape([self.s, self.s])
                p = mapRange(p, torch.min(p), torch.max(p), 0.0, 1.0)

                # here, we basically form a rank-1 "tensor", to add to the target sample's activations.
                # intuitively, the non-zero spatial positions of the part are filled with the appearance vector.
                direc += lam[i] * tl.tenalg.outer([a, p])

            if self.gan_type in ['stylegan', 'stylegan2']:
                noise = self.generator.mapping(noise)['w']
                noise_trunc = self.generator.truncation(noise, trunc_psi=self.trunc_psi, trunc_layers=self.trunc_layers)

                Z = self.generator.synthesis(noise_trunc, start=self.start, stop=self.layer)['x']
                x = self.generator.synthesis(noise_trunc, x=Z, start=self.layer)['image']
                x_prime = self.generator.synthesis(noise_trunc, x=Z + direc, start=self.layer)['image']

            elif 'pggan' in self.gan_type:
                Z = self.generator(noise, start=self.start, stop=self.layer)['x']
                x = self.generator(Z, start=self.layer)['image']
                x_prime = self.generator(Z + direc, start=self.layer)['image']

            elif 'biggan' in self.gan_type:
                print(f'Choosing a {self.biggan_classes[b_idx]}')
                class_vector = torch.tensor(one_hot_from_names([self.biggan_classes[b_idx]]), device=self.device)
                noise_vector = torch.tensor(truncated_noise_sample(truncation=self.trunc_psi, batch_size=1, seed=t), device=self.device)

                result = self.generator(noise_vector, class_vector, self.trunc_psi, stop=self.layer)
                Z, cond_vector = result['z'], result['cond_vector']

                x = self.generator(Z, class_vector, self.trunc_psi, cond_vector=cond_vector, start=self.layer)['z']
                x_prime = self.generator(Z + direc, class_vector, self.trunc_psi, cond_vector=cond_vector, start=self.layer)['z']

            elif 'stylegan3' in self.gan_type:
                label = torch.zeros([1, 0], device=self.device)
                Z = self.generator(noise, label, stop=self.layer, truncation_psi=self.trunc_psi, noise_mode='const')
                x = self.generator(noise, label, x=Z, start=self.layer, stop=None, truncation_psi=self.trunc_psi, noise_mode='const')
                x_prime = self.generator(noise, label, x=Z + direc, start=self.layer, stop=None, truncation_psi=self.trunc_psi, noise_mode='const')

            else:
                raise ValueError(f'Unsupported GAN type: {self.gan_type}')

            image = np.array(Image.fromarray(postprocess(x.cpu().numpy())[0]).resize((256, 256)))
            image2 = np.array(Image.fromarray(postprocess(x_prime.cpu().numpy())[0]).resize((256, 256)))
            part = np.array(Image.fromarray(p.detach().cpu().numpy() * 255).convert('RGB').resize((256, 256), Image.NEAREST))

        return Z, image, image2, part

    def sample(self, noise, layer=5, partial=False, trunc_psi=1.0, trunc_layers=18, verbose=False):
        """
        Samples intermediate feature maps and resulting image the desired generator.

        Parameters
        ----------
        noise : torch.Tensor
            (batch_size, z_dim)-dim random standard gaussian noise.
            Note: the BigGAN branch only uses its batch size; the latent itself is drawn
            with `truncated_noise_sample` at `self.trunc_psi`.
        layer : int
            Intermediate layer at which to return intermediate features.
        partial : bool
            If True, stop at `layer` and return only the intermediate activations;
            else perform the full forward pass and return the image too.
        trunc_psi : float
            Truncation value in [0, 1] (used by the StyleGAN branches).
        trunc_layers : int
            Number of layers at which to apply truncation.
        verbose : bool
            Print out additional information?

        Returns
        -------
        Z : torch.Tensor
            The intermediate activations at `layer`, with a leading batch dimension.
        image : np.array
            Output RGB image of the first sample in the batch, resized to 256x256
            (None if `partial` is True).

        Raises
        ------
        ValueError
            If the GAN type is not supported.
        """
        with torch.no_grad():
            if self.gan_type in ['stylegan', 'stylegan2']:
                noise = self.generator.mapping(noise)['w']
                noise_trunc = self.generator.truncation(noise, trunc_psi=trunc_psi, trunc_layers=trunc_layers)

                Z = self.generator.synthesis(noise_trunc, start=self.start, stop=layer)['x']
                if not partial:
                    x = self.generator.synthesis(noise_trunc, x=Z, start=layer)['image']

            elif 'pggan' in self.gan_type:
                Z = self.generator(noise, start=self.start, stop=layer)['x']
                if not partial:
                    x = self.generator(Z, start=layer)['image']

            elif 'biggan' in self.gan_type:
                if verbose:
                    print(f'Using BigGAN class names: {", ".join(self.biggan_classes)}')

                n = noise.shape[0]
                class_vector = torch.tensor(one_hot_from_names(list(np.random.choice(self.biggan_classes, n)), batch_size=n), device=self.device)
                noise_vector = torch.tensor(truncated_noise_sample(truncation=self.trunc_psi, batch_size=n), device=self.device)

                result = self.generator(noise_vector, class_vector, self.trunc_psi, stop=layer)
                Z = result['z']
                cond_vector = result['cond_vector']

                if not partial:
                    x = self.generator(Z, class_vector, self.trunc_psi, cond_vector=cond_vector, start=layer)['z']

            elif 'stylegan3' in self.gan_type:
                label = torch.zeros([noise.shape[0], 0], device=self.device)
                if hasattr(self.generator.synthesis, 'input'):
                    m = np.linalg.inv(make_transform((0, 0), 0))
                    self.generator.synthesis.input.transform.copy_(torch.from_numpy(m))

                Z = self.generator(noise, label, x=None, start=0, stop=layer, truncation_psi=trunc_psi, noise_mode='const')
                if not partial:
                    x = self.generator(noise, label, x=Z, start=layer, stop=None, truncation_psi=trunc_psi, noise_mode='const')

            else:
                raise ValueError(f'Unsupported GAN type: {self.gan_type}')

            if verbose:
                print(f'-- Partial Z shape at layer {layer}: {Z.shape}')

            if partial:
                return Z, None

            image = postprocess(x.detach().cpu().numpy())
            # only the first image of the batch is returned
            image = np.array(Image.fromarray(image[0]).resize((256, 256)))

            return Z, image

    def save(self):
        """
        Saves the learnt factors `self.Us` and `self.Uc` as .npy files under ./checkpoints
        (the directory is created if missing). Requires a decomposition to have been run,
        so that `self.ranks`, `self.Us` and `self.Uc` exist.
        """
        import os

        os.makedirs('./checkpoints', exist_ok=True)

        Uc_path = f'./checkpoints/Uc-name_{self.model_name}-layer_{self.layer}-rank_{self.ranks[0]}.npy'
        Us_path = f'./checkpoints/Us-name_{self.model_name}-layer_{self.layer}-rank_{self.ranks[1]}.npy'

        np.save(Us_path, self.Us.detach().cpu().numpy())
        np.save(Uc_path, self.Uc.detach().cpu().numpy())
        print(f'Saved factors to {Us_path}, {Uc_path}')