File size: 3,925 Bytes
4347fc3
 
 
973a4da
 
708c2f6
973a4da
 
 
e79558d
70dfa79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4347fc3
 
60af39a
4347fc3
 
60af39a
4347fc3
944e312
4098d00
944e312
e79558d
 
 
 
 
 
 
 
 
 
 
 
 
4347fc3
3f788ef
944e312
 
 
 
 
 
 
 
 
 
 
 
 
60af39a
944e312
 
60af39a
944e312
1efb52e
944e312
 
70dfa79
 
 
 
 
 
 
 
 
 
944e312
973a4da
3b15df8
 
 
 
3f788ef
 
c3c13cc
3f788ef
973a4da
3f788ef
973a4da
3b15df8
a16f7a6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
import torch
import PIL
from tqdm import tqdm
import random
from PIL import Image



def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W', negative_colors=None):
    """
    Generate a single image from a latent code shifted along several directions.

    Starting from ``z``, every decision boundary is added to the latent code,
    scaled by its matching lambda.  When a negative boundary is supplied for a
    direction, the component of the direction that lies along the negative
    boundary is removed first (``d - (n . d) * n``), so that moving along the
    direction does not also move along the negative boundary.  The resulting
    latent code is then rendered by the generator on the CPU.

    :param model: Generator exposing ``c_dim``, ``mapping`` and ``synthesis``;
        it is moved to the CPU with ``model.to``
    :param z: NumPy array with the starting latent code
        (assumed to have shape (1, dim) -- TODO confirm against callers)
    :param decision_boundaries: Iterable of NumPy arrays, one direction per edit
    :param lambdas: Iterable of step sizes, one per decision boundary; each
        value is truncated with ``int()`` before use
    :param latent_space: ``'Z'`` to pass the shifted code through ``G.mapping``;
        any other value treats the code as a W vector and tiles it
    :param negative_colors: Optional iterable aligned with ``decision_boundaries``;
        each entry is either a NumPy array to project out, or ``None`` / the
        string ``'None'`` to leave the direction unchanged
    :return: The generated image as an RGB ``PIL.Image``
    """
    device = torch.device('cpu')
    G = model.to(device) # type: ignore

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)

    z_0 = torch.from_numpy(z.copy()).to(device)
    # Number of copies of the W vector fed to the synthesis network.
    # NOTE(review): presumably has to match the generator's number of style
    # inputs -- confirm before using a different model.
    repetitions = 16

    if negative_colors:
        edits = zip(decision_boundaries, lambdas, negative_colors)
    else:
        edits = ((boundary, lmbd, None) for boundary, lmbd in zip(decision_boundaries, lambdas))

    for boundary, lmbd, neg_boundary in edits:
        direction = torch.from_numpy(boundary.copy()).to(device)

        # Comparing an ndarray against the string 'None' is ambiguous (it is an
        # elementwise comparison on current NumPy), so test the marker explicitly.
        has_negative = neg_boundary is not None and not (
            isinstance(neg_boundary, str) and neg_boundary == 'None'
        )
        if has_negative:
            neg_direction = torch.from_numpy(neg_boundary.copy()).to(device)
            # Scalar inner product n . d; an elementwise product would broadcast
            # to a matrix (or rescale per coordinate) instead of projecting.
            # Assumes the negative boundary is unit-norm -- TODO confirm.
            overlap = torch.sum(neg_direction * direction)
            direction = direction - overlap * neg_direction

        z_0 = z_0 + int(lmbd) * direction

    # Pure inference: do not record an autograd graph.
    with torch.no_grad():
        if latent_space == 'Z':
            W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
        else:
            W_0 = z_0.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)

        img = G.synthesis(W_0, noise_mode='const')

    # Channels-last layout, then rescale to 8-bit pixel values
    # (presumably the generator output is roughly in [-1, 1]).
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

    return img
    
  
def generate_original_image(z, model, latent_space='W'):
    """
    Render the image that corresponds to an unmodified latent code.

    The model is moved to the CPU and run with constant noise.  In ``'Z'``
    mode the code is passed through the full generator call; otherwise it is
    treated as a W vector, repeated ``repetitions`` times and passed directly
    to ``G.synthesis``.

    :param z: NumPy array with the latent code; the W branch reads
        ``z.shape[1]``, so a 2-D array is required there
        (assumed shape (1, dim) -- TODO confirm against callers)
    :param model: Generator exposing ``c_dim`` and ``synthesis`` and callable
        as ``G(z, label, truncation_psi=..., noise_mode=...)``
    :param latent_space: ``'Z'`` to run the whole generator on ``z``; any other
        value feeds the tiled code straight to the synthesis network
    :return: The generated image as an RGB ``PIL.Image``
    """
    # Number of copies of the W vector fed to the synthesis network.
    # NOTE(review): presumably has to match the generator's number of style
    # inputs -- confirm before using a different model.
    repetitions = 16

    device = torch.device('cpu')
    G = model.to(device) # type: ignore
    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)

    # Pure inference: do not record an autograd graph.
    with torch.no_grad():
        if latent_space == 'Z':
            z = torch.from_numpy(z.copy()).to(device)
            img = G(z, label, truncation_psi=1, noise_mode='const')
        else:
            tiled = np.repeat(z, repetitions, axis=0).reshape(1, repetitions, z.shape[1]).copy()
            # Cast to float32 to match what generate_composite_images feeds
            # to the synthesis network.
            W = torch.from_numpy(tiled).to(device).to(torch.float32)
            img = G.synthesis(W, noise_mode='const')

    # Channels-last layout, then rescale to 8-bit pixel values
    # (presumably the generator output is roughly in [-1, 1]).
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')