import argparse
import copy
import os
import time
from tqdm import tqdm
import numpy as np
import PIL.Image
import torch
import clip
from wrapper import (FaceLandmarksDetector, Generator_wrapper,
VGGFeatExtractor, e4eEncoder, PivotTuning)
from projector import project
class Manipulator():
    """Manipulator for style editing

    In the paper, 100 image pairs are used to estimate the mean for alpha
    (the magnitude of the perturbation) in [-5, 5].
    *** Args ***
        G : Generator wrapper for synthesizing styles
        device : torch.device
        lst_alpha : magnitudes of the perturbation
        num_images : number of images to process
    *** Attributes ***
        S : List[dict(str, torch.Tensor)] # length 2,000
        styles : List[dict(str, torch.Tensor)] # length num_images
            (num_images, style)
        lst_alpha : List[int]
        boundary : (num_images, len_alpha)
        edited_styles : List[styles]
        edited_images : List[(num_images, 3, 1024, 1024)]
    """
def __init__(
self,
G,
device,
lst_alpha=[0],
num_images=1,
start_ind=0,
face_preprocess=True,
dataset_name=''
):
"""Initialize
- use pre-saved generated latent/style from random Z
- to use projection, used method "set_real_img_projection"
"""
assert start_ind + num_images < 2000
self.W = torch.load(f'tensor/W{dataset_name}.pt')
self.S = torch.load(f'tensor/S{dataset_name}.pt')
self.S_mean = torch.load(f'tensor/S_mean{dataset_name}.pt')
self.S_std = torch.load(f'tensor/S_std{dataset_name}.pt')
self.S = {layer: self.S[layer].to(device) for layer in G.style_layers}
self.styles = {layer: self.S[layer][start_ind:start_ind+num_images] for layer in G.style_layers}
self.latent = self.W[start_ind:start_ind+num_images]
self.latent = self.latent.to(device)
del self.W
del self.S
# S_mean, S_std for extracting global style direction
self.S_mean = {layer: self.S_mean[layer].to(device) for layer in G.style_layers}
self.S_std = {layer: self.S_std[layer].to(device) for layer in G.style_layers}
# setting
self.face_preprocess = face_preprocess
if face_preprocess:
self.landmarks_detector = FaceLandmarksDetector()
self.vgg16 = VGGFeatExtractor(device).module
self.W_projector_steps = 200
self.G = G
self.device = device
self.num_images = num_images
self.lst_alpha = lst_alpha
self.manipulate_layers = [layer for layer in G.style_layers if 'torgb' not in layer]
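    # Usage sketch (illustrative only): construct a Manipulator from the
    # pre-saved tensors, assuming tensor/W.pt, tensor/S.pt, tensor/S_mean.pt and
    # tensor/S_std.pt exist for the given dataset_name:
    #
    #   manipulator = Manipulator(G, device, lst_alpha=[-5, 0, 5], num_images=4)
    #   print(manipulator.styles[G.style_layers[0]].shape)  # (4, num_channels)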
def set_alpha(self, lst_alpha):
"""Setter for alpha
"""
self.lst_alpha = lst_alpha
    def set_real_img_projection(self, img, inv_mode='w', pti_mode=None):
        """Set real images instead of the pre-saved styles

        Args :
            - img : image directory or image file path to manipulate
                - faces are aligned if self.face_preprocess == True
                - sets self.num_images
            - inv_mode : inversion mode, sets self.latent and self.styles
                - 'w'  : use the W projector (projector.project)
                - 'w+' : use the e4e encoder (wrapper.e4eEncoder)
            - pti_mode : pivot tuning inversion mode (wrapper.PivotTuning)
                - None
                - 'w' : W latent pivot tuning
                - 's' : S style pivot tuning
        """
assert inv_mode in ['w', 'w+']
assert pti_mode in [None, 'w', 's']
allowed_extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
# img directory input
if os.path.isdir(img):
imgpaths = sorted(os.listdir(img))
imgpaths = [os.path.join(img, imgpath)
for imgpath in imgpaths
if imgpath.split('.')[-1] in allowed_extensions]
# img file path input
else:
imgpaths = [img]
self.num_images = len(imgpaths)
if inv_mode == 'w':
targets = list()
target_pils = list()
for imgpath in imgpaths:
if self.face_preprocess:
target_pil = self.landmarks_detector(imgpath)
else:
target_pil = PIL.Image.open(imgpath).convert('RGB')
target_pils.append(target_pil)
w, h = target_pil.size
s = min(w, h)
target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
target_pil = target_pil.resize((self.G.G.img_resolution, self.G.G.img_resolution),
PIL.Image.LANCZOS)
target_uint8 = np.array(target_pil, dtype=np.uint8)
targets.append(torch.Tensor(target_uint8.transpose([2,0,1])).to(self.device))
self.latent = list()
for target in tqdm(targets, total=len(targets)):
projected_w_steps = project(
self.G.G,
self.vgg16,
target=target,
num_steps=self.W_projector_steps, # TODO get projector steps from configs
device=self.device,
verbose=False,
)
self.latent.append(projected_w_steps[-1])
self.latent = torch.stack(self.latent)
self.styles = self.G.mapping_stylespace(self.latent)
else: # inv_mode == 'w+'
# use e4e encoder
target_pils = list()
for imgpath in imgpaths:
if self.face_preprocess:
target_pil = self.landmarks_detector(imgpath)
else:
target_pil = PIL.Image.open(imgpath).convert('RGB')
target_pils.append(target_pil)
self.encoder = e4eEncoder(self.device)
self.latent = self.encoder(target_pils)
self.styles = self.G.mapping_stylespace(self.latent)
if pti_mode is not None: # w or s
# pivot tuning inversion
pti = PivotTuning(self.device, self.G.G, mode=pti_mode)
new_G = pti(self.latent, target_pils)
self.G.G = new_G
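    # Usage sketch (illustrative; the image path is hypothetical): invert a real
    # photo before editing, here with the e4e encoder plus W-pivot tuning:
    #
    #   manipulator.set_real_img_projection('imgs/face.jpg', inv_mode='w+', pti_mode='w')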
    def manipulate(self, delta_s):
        """Edit styles by the given delta style
        - adds the perturbation (delta s) * (alpha) for each alpha in lst_alpha
        """
styles = [copy.deepcopy(self.styles) for _ in range(len(self.lst_alpha))]
for (alpha, style) in zip(self.lst_alpha, styles):
for layer in self.G.style_layers:
perturbation = delta_s[layer] * alpha
style[layer] += perturbation
return styles
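    # Usage sketch (illustrative): a zero delta_s leaves the styles unchanged;
    # in practice delta_s comes from a global style direction (see
    # extract_global_direction below):
    #
    #   delta_s = {layer: torch.zeros_like(manipulator.styles[layer][0])
    #              for layer in G.style_layers}
    #   edited = manipulator.manipulate(delta_s)  # len(edited) == len(lst_alpha)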
    def manipulate_one_channel(self, layer, channel_ind: int):
        """Edit styles at the given layer and channel index
        - replaces the channel with the mean of the pre-saved styles
        - uses the perturbation (pre-saved style std) * (alpha) as a boundary
        """
assert layer in self.G.style_layers
assert 0 <= channel_ind < self.styles[layer].shape[1]
boundary = self.S_std[layer][channel_ind].item()
# apply self.S_mean value for given layer, channel_ind
for ind in range(self.num_images):
self.styles[layer][ind][channel_ind] = self.S_mean[layer][channel_ind]
styles = [copy.deepcopy(self.styles) for _ in range(len(self.lst_alpha))]
perturbation = (torch.Tensor(self.lst_alpha) * boundary).numpy().tolist()
# apply one channel manipulation
for img_ind in range(self.num_images):
for edit_ind, delta in enumerate(perturbation):
styles[edit_ind][layer][img_ind][channel_ind] += delta
return styles
    def synthesis_from_styles(self, styles, start_ind, end_ind):
        """Synthesize images from the edited styles for the slice [start_ind, end_ind)
        """
styles_ = list()
for style in styles:
style_ = dict()
for layer in self.G.style_layers:
style_[layer] = style[layer][start_ind:end_ind].to(self.device)
styles_.append(style_)
print("synthesis_from_styles", type(style_))
imgs = [self.G.synthesis_from_stylespace(self.latent[start_ind:end_ind], style_).cpu()
for style_ in styles_]
return imgs
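# Note: synthesis_from_styles returns float tensors in StyleGAN's [-1, 1] range.
# A sketch of converting one returned batch to PIL images (the same conversion
# used in extract_global_direction below):
#
#   arr = (imgs[0].permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).numpy()
#   pil = PIL.Image.fromarray(arr[0])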
def extract_global_direction(G, device, lst_alpha, num_images, dataset_name=''):
    """Extract the global style direction from 100 images
    """
assert len(lst_alpha) == 2
model, preprocess = clip.load("ViT-B/32", device=device)
# lindex in original tf version
manipulate_layers = [layer for layer in G.style_layers if 'torgb' not in layer]
# total channel: 6048 (1024 resolution)
resolution = G.G.img_resolution
    # num_ws per resolution: 1024 -> 18, 512 -> 16, 256 -> 14
    latent = torch.randn([1, G.to_w_idx[f'G.synthesis.b{resolution}.torgb.affine'] + 1, 512]).to(device)
style = G.mapping_stylespace(latent)
cnt = 0
for layer in manipulate_layers:
cnt += style[layer].shape[1]
del latent
del style
# 1024 -> 6048 channels, 256 -> 4928 channels
print(f"total channels to manipulate: {cnt}")
manipulator = Manipulator(G, device, lst_alpha, num_images, face_preprocess=False, dataset_name=dataset_name)
all_feats = list()
for layer in manipulate_layers:
print(f'\nStyle manipulation in layer "{layer}"')
channel_num = manipulator.styles[layer].shape[1]
for channel_ind in tqdm(range(channel_num), total=channel_num):
styles = manipulator.manipulate_one_channel(layer, channel_ind)
            # 2 * num_images images, synthesized in batches
            batchsize = 10
            nbatch = num_images // batchsize
            feats = list()
            for batch_ind in range(nbatch):  # each batch yields 2 * batchsize images
                start = batch_ind * batchsize
                end = start + batchsize
synth_imgs = manipulator.synthesis_from_styles(styles, start, end)
synth_imgs = [(synth_img.permute(0,2,3,1)*127.5+128).clamp(0,255).to(torch.uint8).numpy()
for synth_img in synth_imgs]
imgs = list()
for i in range(batchsize):
img0 = PIL.Image.fromarray(synth_imgs[0][i])
img1 = PIL.Image.fromarray(synth_imgs[1][i])
imgs.append(preprocess(img0).unsqueeze(0).to(device))
imgs.append(preprocess(img1).unsqueeze(0).to(device))
with torch.no_grad():
feat = model.encode_image(torch.cat(imgs))
feats.append(feat)
all_feats.append(torch.cat(feats).view([-1, 2, 512]).cpu())
all_feats = torch.stack(all_feats).numpy()
    fs = all_feats  # (num_channels, num_images, 2, 512)
    fs1 = fs / np.linalg.norm(fs, axis=-1)[:, :, :, None]  # normalize CLIP features
    fs2 = fs1[:, :, 1, :] - fs1[:, :, 0, :]  # (+alpha) - (-alpha), e.g. 5*sigma - (-5)*sigma
    fs3 = fs2 / np.linalg.norm(fs2, axis=-1)[:, :, None]  # normalize the per-image differences
    fs3 = fs3.mean(axis=1)  # average the direction over images
    fs3 = fs3 / np.linalg.norm(fs3, axis=-1)[:, None]  # renormalize the mean direction
np.save(f'tensor/fs3{dataset_name}.npy', fs3) # global style direction
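# Note: fs3 holds one normalized CLIP-space direction per style channel. In the
# StyleCLIP global-directions method, a text edit direction dt in CLIP space is
# projected onto fs3 to score channel relevance, weak channels are zeroed with a
# threshold beta, and the result is scaled by S_std to form delta_s. The sketch
# below paraphrases that procedure and is not code from this file:
#
#   relevance = fs3 @ dt                       # (num_channels,)
#   relevance[np.abs(relevance) < beta] = 0    # keep only text-relevant channels
#   # split `relevance` back into per-layer channels and multiply by S_std[layer]
#   # to obtain delta_s for Manipulator.manipulate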
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('runtype', type=str, default='test')
parser.add_argument('--ckpt', type=str, default='pretrained/ffhq.pkl')
    parser.add_argument('--face_preprocess', default=True,
                        type=lambda x: str(x).lower() in ('true', '1'))  # argparse's type=bool treats any non-empty string, including 'False', as True
parser.add_argument('--dataset_name', type=str, default='')
args = parser.parse_args()
runtype = args.runtype
assert runtype in ['test', 'extract']
device = torch.device('cuda:0')
ckpt = args.ckpt
    G = Generator_wrapper(ckpt, device)
face_preprocess = args.face_preprocess
dataset_name = args.dataset_name
if runtype == 'test': # test manipulator
num_images = 100
lst_alpha = [-5, 0, 5]
layer = G.style_layers[6]
channel_ind = 501
manipulator = Manipulator(G, device, lst_alpha, num_images, face_preprocess=face_preprocess, dataset_name=dataset_name)
styles = manipulator.manipulate_one_channel(layer, channel_ind)
        start_ind, end_ind = 0, 10
imgs = manipulator.synthesis_from_styles(styles, start_ind, end_ind)
print(len(imgs), imgs[0].shape)
elif runtype == 'extract': # extract global style direction from "tensor/S.pt"
num_images = 100
lst_alpha = [-5, 5]
extract_global_direction(G, device, lst_alpha, num_images, dataset_name=dataset_name)