import argparse
import copy
import os
import time
from tqdm import tqdm

import numpy as np
import PIL.Image
import torch

import clip
from wrapper import (FaceLandmarksDetector, Generator_wrapper, 
                     VGGFeatExtractor, e4eEncoder, PivotTuning)
from projector import project 

class Manipulator():
    """Manipulator for style editing

    in paper, use 100 image pairs to estimate the mean for alpha(magnitude of the perturbation) [-5, 5]

    *** Args ***
    G : Genertor wrapper for synthesis styles
    device : torch.device
    lst_alpha : magnitude of the perturbation
    num_images : num images to process

    *** Attributes ***
    S :  List[dict(str, torch.Tensor)] # length 2,000
    styles : List[dict(str, torch.Tensor)] # length of num_images
                (num_images, style)
    lst_alpha : List[int]
    boundary : (num_images, len_alpha)
    edited_styles : List[styles]
    edited_images : List[(num_images, 3, 1024, 1024)]
    """
    def __init__(
        self, 
        G, 
        device, 
        lst_alpha=[0], 
        num_images=1, 
        start_ind=0, 
        face_preprocess=True,
        dataset_name=''
    ):
        """Initialize 
        - use pre-saved generated latent/style from random Z
        - to use projection, used method "set_real_img_projection"
        """
        assert start_ind + num_images <= 2000
        self.W = torch.load(f'tensor/W{dataset_name}.pt')
        self.S = torch.load(f'tensor/S{dataset_name}.pt')
        self.S_mean = torch.load(f'tensor/S_mean{dataset_name}.pt')
        self.S_std = torch.load(f'tensor/S_std{dataset_name}.pt')

        self.S = {layer: self.S[layer].to(device) for layer in G.style_layers}
        self.styles = {layer: self.S[layer][start_ind:start_ind+num_images] for layer in G.style_layers}
        self.latent = self.W[start_ind:start_ind+num_images]
        self.latent = self.latent.to(device)
        del self.W
        del self.S

        # S_mean, S_std for extracting global style direction
        self.S_mean = {layer: self.S_mean[layer].to(device) for layer in G.style_layers}
        self.S_std = {layer: self.S_std[layer].to(device) for layer in G.style_layers}

        # setting
        self.face_preprocess = face_preprocess
        if face_preprocess:
            self.landmarks_detector = FaceLandmarksDetector()
        self.vgg16 = VGGFeatExtractor(device).module
        self.W_projector_steps = 200
        self.G = G
        self.device = device
        self.num_images = num_images
        self.lst_alpha = lst_alpha
        self.manipulate_layers = [layer for layer in G.style_layers if 'torgb' not in layer] 

    def set_alpha(self, lst_alpha):
        """Setter for alpha
        """
        self.lst_alpha = lst_alpha

    def set_real_img_projection(self, img, inv_mode='w', pti_mode=None):
        """Set real img instead of pre-saved styles
        Args : 
        - img : img directory or img file path to manipulate
            - face aligned if self.face_preprocess == True
            - set self.num_images
        - inv_mode : inversion mode, setting self.latent, self.styles
            - w : use W projector (projector.project)
            - w+ : use e4e encoder (wrapper.e4eEncoder)
        - pti_mode : pivot tuning inversion mode (wrapper.PivotTuning)
            - None
            - w : W latent pivot tuning
            - s : S style pivot tuning
        """
        assert inv_mode in ['w', 'w+']
        assert pti_mode in [None, 'w', 's']
        allowed_extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']

        # img directory input
        if os.path.isdir(img):
            imgpaths = sorted(os.listdir(img))
            imgpaths = [os.path.join(img, imgpath) 
                        for imgpath in imgpaths 
                        if imgpath.split('.')[-1] in allowed_extensions]
        # img file path input
        else:
            imgpaths = [img]

        self.num_images = len(imgpaths)
        if inv_mode == 'w':
            targets = list()
            target_pils = list()
            for imgpath in imgpaths:
                if self.face_preprocess:
                    target_pil = self.landmarks_detector(imgpath)
                else:
                    target_pil = PIL.Image.open(imgpath).convert('RGB')
                target_pils.append(target_pil)
                w, h = target_pil.size
                s = min(w, h)
                target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
                target_pil = target_pil.resize((self.G.G.img_resolution, self.G.G.img_resolution), 
                                                PIL.Image.LANCZOS)
                target_uint8 = np.array(target_pil, dtype=np.uint8)
                targets.append(torch.Tensor(target_uint8.transpose([2,0,1])).to(self.device))

            self.latent = list()
            for target in tqdm(targets, total=len(targets)):
                projected_w_steps = project(
                    self.G.G,
                    self.vgg16,
                    target=target,
                    num_steps=self.W_projector_steps, # TODO get projector steps from configs
                    device=self.device,
                    verbose=False,
                )
                self.latent.append(projected_w_steps[-1])
            self.latent = torch.stack(self.latent)
            self.styles = self.G.mapping_stylespace(self.latent)

        else: # inv_mode == 'w+'
            # use e4e encoder
            target_pils = list()
            for imgpath in imgpaths:
                if self.face_preprocess:
                    target_pil = self.landmarks_detector(imgpath)
                else:
                    target_pil = PIL.Image.open(imgpath).convert('RGB')
                target_pils.append(target_pil)

            self.encoder = e4eEncoder(self.device)
            self.latent = self.encoder(target_pils)
            self.styles = self.G.mapping_stylespace(self.latent)

        if pti_mode is not None: # w or s
            # pivot tuning inversion 
            pti = PivotTuning(self.device, self.G.G, mode=pti_mode)
            new_G = pti(self.latent, target_pils)
            self.G.G = new_G

    def manipulate(self, delta_s):
        """Edit style by given delta_style
        - use perturbation (delta s) * (alpha) as a boundary
        """
        styles = [copy.deepcopy(self.styles) for _ in range(len(self.lst_alpha))]

        for (alpha, style) in zip(self.lst_alpha, styles):
            for layer in self.G.style_layers:
                perturbation = delta_s[layer] * alpha
                style[layer] += perturbation
        return styles

    def manipulate_one_channel(self, layer, channel_ind:int):
        """Edit style from given layer, channel index
        - use mean value of pre-saved style
        - use perturbation (pre-saved style std) * (alpha) as a boundary
        """
        assert layer in self.G.style_layers
        assert 0 <= channel_ind < self.styles[layer].shape[1]
        boundary = self.S_std[layer][channel_ind].item()
        # apply self.S_mean value for given layer, channel_ind
        for ind in range(self.num_images):
            self.styles[layer][ind][channel_ind] = self.S_mean[layer][channel_ind]
        styles = [copy.deepcopy(self.styles) for _ in range(len(self.lst_alpha))]
        
        perturbation = (torch.Tensor(self.lst_alpha) * boundary).numpy().tolist()
       
        # apply one channel manipulation
        for img_ind in range(self.num_images):
            for edit_ind, delta in enumerate(perturbation):
                styles[edit_ind][layer][img_ind][channel_ind] += delta

        return styles

    def synthesis_from_styles(self, styles, start_ind, end_ind):
        """Synthesis edited styles from styles, lst_alpha
        """
        styles_ = list()
        for style in styles:
            style_ = dict()
            for layer in self.G.style_layers:
                style_[layer] = style[layer][start_ind:end_ind].to(self.device)
            styles_.append(style_)
        print("synthesis_from_styles", type(style_))
        imgs = [self.G.synthesis_from_stylespace(self.latent[start_ind:end_ind], style_).cpu() 
                for style_ in styles_]
        return imgs
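
# A minimal usage sketch of Manipulator (illustrative only; the image path and the
# delta_s placeholder below are assumptions, not assets shipped with this file):
#
#   device = torch.device('cuda:0')
#   G = Generator_wrapper('pretrained/ffhq.pkl', device)
#   manipulator = Manipulator(G, device, lst_alpha=[-5, 0, 5], num_images=1)
#   manipulator.set_real_img_projection('sample.png', inv_mode='w+', pti_mode=None)
#   delta_s = ...  # dict of per-layer style offsets, e.g. built from the saved fs3 direction
#   styles = manipulator.manipulate(delta_s)
#   imgs = manipulator.synthesis_from_styles(styles, 0, manipulator.num_images)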


def extract_global_direction(G, device, lst_alpha, num_images, dataset_name=''):
    """Extract global style direction in 100 images
    """
    assert len(lst_alpha) == 2
    model, preprocess = clip.load("ViT-B/32", device=device)
    
    # corresponds to "lindex" in the original TensorFlow implementation
    manipulate_layers = [layer for layer in G.style_layers if 'torgb' not in layer] 

    # total channel: 6048 (1024 resolution)
    resolution = G.G.img_resolution
    latent = torch.randn([1,G.to_w_idx[f'G.synthesis.b{resolution}.torgb.affine']+1,512]).to(device) # 1024 -> 18, 512 -> 16, 256 -> 14
    style = G.mapping_stylespace(latent)
    cnt = 0
    for layer in manipulate_layers:
        cnt += style[layer].shape[1]
    del latent
    del style

    # 1024 -> 6048 channels, 256 -> 4928 channels
    print(f"total channels to manipulate: {cnt}")
    
    manipulator = Manipulator(G, device, lst_alpha, num_images, face_preprocess=False, dataset_name=dataset_name)

    all_feats = list()

    for layer in manipulate_layers:
        print(f'\nStyle manipulation in layer "{layer}"')
        channel_num = manipulator.styles[layer].shape[1]

        for channel_ind in tqdm(range(channel_num), total=channel_num):
            styles = manipulator.manipulate_one_channel(layer, channel_ind)
            # 2 alphas * num_images images, processed in batches of batchsize
            batchsize = 10
            nbatch = num_images // batchsize
            feats = list()
            for batch_ind in range(nbatch): # 2 * batchsize images per batch
                start = batch_ind * batchsize
                end = start + batchsize
                synth_imgs = manipulator.synthesis_from_styles(styles, start, end)
                synth_imgs = [(synth_img.permute(0,2,3,1)*127.5+128).clamp(0,255).to(torch.uint8).numpy()
                            for synth_img in synth_imgs]
                imgs = list()
                for i in range(batchsize):
                    img0 = PIL.Image.fromarray(synth_imgs[0][i])
                    img1 = PIL.Image.fromarray(synth_imgs[1][i])
                    imgs.append(preprocess(img0).unsqueeze(0).to(device))
                    imgs.append(preprocess(img1).unsqueeze(0).to(device))
                with torch.no_grad():
                    feat = model.encode_image(torch.cat(imgs))
                feats.append(feat)
            all_feats.append(torch.cat(feats).view([-1, 2, 512]).cpu())

    all_feats = torch.stack(all_feats).numpy()
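
    # all_feats has shape (total_channels, num_images, 2, 512): for each manipulated
    # channel, the CLIP features of every image rendered at alpha = lst_alpha[0]
    # (index 0) and alpha = lst_alpha[1] (index 1). The global direction fs3 is the
    # per-channel difference of the normalized features, averaged over images and
    # re-normalized.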

    fs = all_feats
    fs1 = fs / np.linalg.norm(fs, axis=-1)[:, :, :, None]
    fs2 = fs1[:, :, 1, :] - fs1[:, :, 0, :] # 5*sigma - (-5)*sigma
    fs3 = fs2 / np.linalg.norm(fs2, axis=-1)[:, :, None]
    fs3 = fs3.mean(axis=1)
    fs3 = fs3 / np.linalg.norm(fs3, axis=-1)[:, None]

    np.save(f'tensor/fs3{dataset_name}.npy', fs3) # global style direction


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('runtype', type=str, nargs='?', default='test')
    parser.add_argument('--ckpt', type=str, default='pretrained/ffhq.pkl')
    # argparse's type=bool treats any non-empty string (e.g. 'False') as True, so parse the flag explicitly
    parser.add_argument('--face_preprocess', type=lambda x: str(x).lower() in ('true', '1', 'yes'), default=True)
    parser.add_argument('--dataset_name', type=str, default='')
    args = parser.parse_args()
    
    runtype = args.runtype
    assert runtype in ['test', 'extract'] 

    device = torch.device('cuda:0')
    ckpt = args.ckpt
    G = Generator_wrapper(ckpt, device)

    face_preprocess = args.face_preprocess
    dataset_name = args.dataset_name

    if runtype == 'test': # test manipulator
        num_images = 100
        lst_alpha = [-5, 0, 5]
        layer = G.style_layers[6]
        channel_ind = 501
        manipulator = Manipulator(G, device, lst_alpha, num_images, face_preprocess=face_preprocess, dataset_name=dataset_name)
        styles = manipulator.manipulate_one_channel(layer, channel_ind)
        start_ind, end_ind = 0, 10
        imgs = manipulator.synthesis_from_styles(styles, start_ind, end_ind)
        print(len(imgs), imgs[0].shape)

    elif runtype == 'extract': # extract global style direction from "tensor/S.pt"
        num_images = 100
        lst_alpha = [-5, 5]
        extract_global_direction(G, device, lst_alpha, num_images, dataset_name=dataset_name)