# latent-space-theories/backend/disentangle_concepts.py
# Source: Hugging Face file page by ludusc, commit e79558d ("third page"), 3.93 kB.
import numpy as np
import torch
import PIL
from tqdm import tqdm
import random
from PIL import Image
def _is_missing_boundary(boundary):
    """Return True when ``boundary`` means "no negative direction" (``None`` or the string ``'None'``)."""
    # Check the type before comparing: ``ndarray != 'None'`` is elementwise in
    # NumPy >= 1.25, and using its result in ``if`` raises ValueError.
    return boundary is None or (isinstance(boundary, str) and boundary == 'None')


def _remove_component(direction, other):
    """Return ``direction`` minus its orthogonal projection onto ``other``.

    The reduction runs over the last axis, so this works for both 1-D vectors
    and ``(1, D)`` row vectors. For a unit-norm ``other`` this equals
    ``direction - (other . direction) * other``.
    """
    dot = (direction * other).sum(dim=-1, keepdim=True)
    # Clamp so an all-zero ``other`` leaves ``direction`` unchanged instead of producing NaN.
    norm_sq = (other * other).sum(dim=-1, keepdim=True).clamp_min(1e-12)
    return direction - (dot / norm_sq) * other


def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W', negative_colors=None):
    """Render one image from a latent code moved along several concept directions.

    Starting from ``z``, every direction in ``decision_boundaries`` is added with
    its step from ``lambdas``: ``z_0 = z + sum(int(lambda_i) * d_i)``. When
    ``negative_colors`` is given, each direction first has its component along
    the paired negative direction removed. This is one Gram-Schmidt step, so the
    edit does not also move along the negative concept.

    :param model: generator exposing ``c_dim``, ``mapping`` and ``synthesis``;
        it is moved to the CPU.
    :param z: numpy latent code. Presumably shape ``(1, D)``; verify with callers.
    :param decision_boundaries: iterable of numpy direction vectors, assumed to
        have the same shape as ``z``.
    :param lambdas: step size for each direction, converted with ``int()``.
    :param latent_space: ``'Z'`` passes the edited code through ``G.mapping``.
        Any other value treats it as one W vector repeated 16 times.
    :param negative_colors: optional iterable parallel to ``decision_boundaries``
        of numpy vectors to project out. An entry of ``None`` or ``'None'``
        means that direction is used unchanged.
    :return: the generated image as an RGB ``PIL.Image``.
    """
    device = torch.device('cpu')
    G = model.to(device)  # type: ignore
    # All-zero conditioning label, one row of width G.c_dim.
    label = torch.zeros([1, G.c_dim], device=device)
    repetitions = 16

    z_0 = torch.from_numpy(z.copy()).to(device)

    if negative_colors:
        steps = zip(decision_boundaries, lambdas, negative_colors)
    else:
        steps = ((boundary, lmbd, None) for boundary, lmbd in zip(decision_boundaries, lambdas))

    for decision_boundary, lmbd, neg_boundary in steps:
        direction = torch.from_numpy(decision_boundary.copy()).to(device)
        if not _is_missing_boundary(neg_boundary):
            negative = torch.from_numpy(neg_boundary.copy()).to(device)
            direction = _remove_component(direction, negative)
        # NOTE(review): int() truncates fractional steps (e.g. 0.5 -> 0); confirm this is intended.
        z_0 = z_0 + int(lmbd) * direction

    if latent_space == 'Z':
        W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
    else:
        # Repeat the single W vector for each of the 16 synthesis inputs -> (1, 16, D).
        W_0 = z_0.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)

    img = G.synthesis(W_0, noise_mode='const')
    # NCHW -> NHWC, then scale to uint8. Presumably the output is in [-1, 1]; confirm for this model.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
def generate_original_image(z, model, latent_space='W'):
    """Render the unedited image for a single latent code.

    :param z: numpy latent code.
        - For ``latent_space='Z'`` it is passed to ``model(...)`` directly.
        - Otherwise it is treated as a W code: its rows are tiled 16 times and
          reshaped to ``(1, 16, z.shape[1])``, so it presumably has shape
          ``(1, D)``. Confirm with callers.
    :param model: generator exposing ``c_dim`` and ``synthesis`` and callable as
        ``model(z, label, truncation_psi=..., noise_mode=...)``. It is moved to
        the CPU.
    :param latent_space: ``'Z'`` runs the full generator. Any other value calls
        only ``model.synthesis``.
    :return: the first generated image as an RGB ``PIL.Image``.
    """
    cpu = torch.device('cpu')
    generator = model.to(cpu)  # type: ignore
    # All-zero conditioning label (one row of width c_dim); only the 'Z' path uses it.
    c = torch.zeros([1, generator.c_dim], device=cpu)
    num_ws = 16

    if latent_space != 'Z':
        tiled = np.repeat(z, num_ws, axis=0).reshape(1, num_ws, z.shape[1])
        ws = torch.from_numpy(tiled.copy()).to(cpu)
        raw = generator.synthesis(ws, noise_mode='const')
    else:
        latent = torch.from_numpy(z.copy()).to(cpu)
        raw = generator(latent, c, truncation_psi=1, noise_mode='const')

    # NCHW -> NHWC, then scale to uint8. Presumably the output is in [-1, 1]; confirm for this model.
    nhwc = raw.permute(0, 2, 3, 1)
    pixels = (nhwc * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    first = pixels[0].cpu().numpy()
    return PIL.Image.fromarray(first, 'RGB')