|
|
|
|
|
|
|
import argparse
import os
import pickle
import subprocess

import numpy as np
import tensorflow as tf

from dnnlib import tflib
|
|
|
def LoadModel(dataset_name):
    """Initialise TensorFlow via tflib and load a pickled network.

    The pickle is read from ``./model/<dataset_name>.pkl`` and must unpack
    into exactly three objects; the third one is returned.

    NOTE(review): ``pickle.load`` executes arbitrary code from the file, so
    only load model files obtained from a trusted source.
    """
    tflib.init_tf()

    pkl_path = os.path.join('./model/', dataset_name + '.pkl')
    with open(pkl_path, 'rb') as handle:
        _first, _second, Gs = pickle.load(handle)
    return Gs
|
|
|
def lerp(a, b, t):
    """Linearly interpolate between ``a`` (at t=0) and ``b`` (at t=1)."""
    delta = b - a
    return a + delta * t
|
|
|
|
|
def SelectName(layer_name, suffix):
    """Decide whether a tensor description belongs to the selected layers.

    Args:
        layer_name: String description of a tensor (the caller in this file
            passes ``str()`` of an operation's output tuple).
        suffix: ``None`` or a string appended to each layer-kind marker.

    Returns:
        bool. The name must always contain ``'G_synthesis_1'``.
        With ``suffix is None`` it must additionally contain both
        ``'add:0'`` and ``'shape=(?,'``. Otherwise it must contain at least
        one of ``'/Conv0_up'``, ``'/Conv1'``, ``'4x4/Conv'`` or ``'/ToRGB'``
        immediately followed by ``suffix``.
    """
    in_synthesis = 'G_synthesis_1' in layer_name

    # Identity test: ``== None`` can be hijacked by a custom ``__eq__``.
    if suffix is None:
        return (
            in_synthesis
            and 'add:0' in layer_name
            and 'shape=(?,' in layer_name
        )

    layer_kinds = ('/Conv0_up', '/Conv1', '4x4/Conv', '/ToRGB')
    return in_synthesis and any(
        (kind + suffix) in layer_name for kind in layer_kinds
    )
|
|
|
|
|
def GetSNames(suffix):
    """Collect the first output tensor of every graph operation that matches.

    Operations are read from the graph attached to a freshly opened
    ``tf.Session``. An operation is kept when ``SelectName`` accepts the
    string form of its output tuple for the given ``suffix``.

    Returns:
        list of the selected operations' first output tensors, in graph order.
    """
    with tf.Session() as sess:
        operations = sess.graph.get_operations()

    outputs_per_op = [operation.values() for operation in operations]
    return [
        outputs[0]
        for outputs in outputs_per_op
        if SelectName(str(outputs), suffix)
    ]
|
|
|
def SelectName2(layer_name):
    """Return True for modulation weight/bias names outside ToRGB layers.

    A name is accepted when it contains ``'mod_bias'`` or ``'mod_weight'``
    and does not contain ``'ToRGB'``.
    """
    if 'ToRGB' in layer_name:
        return False
    return 'mod_bias' in layer_name or 'mod_weight' in layer_name
|
|
|
def GetKName(Gs):
    """Return the synthesis-network variables accepted by ``SelectName2``.

    Every value of ``Gs.components.synthesis.vars`` is tested through its
    string representation; matching variables are returned in the mapping's
    iteration order.
    """
    return [
        variable
        for variable in Gs.components.synthesis.vars.values()
        if SelectName2(str(variable))
    ]
|
|
|
def GetCode(Gs, random_state, num_img, num_once, dataset_name):
    """Sample latent codes, map and truncate them, and save them to disk.

    Random latents are drawn from ``np.random.RandomState(random_state)`` in
    batches of at most ``num_once`` rows and passed through
    ``Gs.components.mapping.run``. The result is interpolated from
    ``dlatent_avg`` with coefficient 0.7 for layer indices below 8 and 1.0
    for the rest. Only layer 0 of each truncated code is kept, and the
    ``(num_img, 512)`` float32 array is written with ``np.save`` to
    ``./npy/<dataset_name>/W``.

    Unlike a plain ``int(num_img / num_once)`` loop, a final partial batch is
    processed when ``num_img`` is not a multiple of ``num_once``, so no
    trailing rows are left as zeros. Results are unchanged when it divides
    evenly.

    Args:
        Gs: Network object providing ``get_var``, ``input_shape`` and
            ``components.mapping.run``.
        random_state: Seed for the NumPy random generator.
        num_img: Total number of codes to generate.
        num_once: Maximum number of codes mapped per batch; must be positive.
        dataset_name: Sub-directory of ``./npy/`` that receives the output.

    Raises:
        ValueError: If ``num_once`` is not positive.
    """
    if num_once <= 0:
        raise ValueError('num_once must be a positive integer')

    rnd = np.random.RandomState(random_state)

    truncation_psi = 0.7
    truncation_cutoff = 8

    dlatent_avg = Gs.get_var('dlatent_avg')

    # NOTE(review): the width 512 is hard-coded; it must match the last
    # dimension produced by the mapping network - confirm for other models.
    dlatents = np.zeros((num_img, 512), dtype='float32')
    for start in range(0, num_img, num_once):
        # The last batch may be shorter than num_once.
        batch = min(num_once, num_img - start)
        src_latents = rnd.randn(batch, Gs.input_shape[1])
        src_dlatents = Gs.components.mapping.run(src_latents, None)

        # Per-layer coefficient: truncation_psi below the cutoff, else 1.
        layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis]
        ones = np.ones(layer_idx.shape, dtype=np.float32)
        coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
        truncated = lerp(dlatent_avg, src_dlatents, coefs)

        # Keep only layer 0 of every code.
        dlatents[start:start + batch, :] = truncated[:, 0, :].astype('float32')
    print('get all z and w')

    np.save('./npy/' + dataset_name + '/W', dlatents)
|
|
|
|
|
def GetImg(Gs, num_img, num_once, dataset_name, save_name='images'):
    """Render images from the saved codes and store them as one array.

    Codes are loaded from ``./npy/<dataset_name>/W.npy``. For each of the
    ``int(num_img / num_once)`` batches, every code is tiled along a new
    second axis ``Gs.components.synthesis.input_shape[1]`` times and rendered
    one at a time with ``randomize_noise=False`` and the uint8/NHWC output
    transform. All results are concatenated and written with ``np.save`` to
    ``./npy/<dataset_name>/<save_name>``.
    """
    print('Generate Image')
    dlatents = np.load('./npy/' + dataset_name + '/W.npy')
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)

    batches = []
    for i in range(int(num_img / num_once)):
        print(i)
        offset = i * num_once
        rendered = []
        for k in range(num_once):
            # Shape the single code as (1, 1, dim), then repeat it per layer.
            code = dlatents[offset + k][None, None, :]
            code = np.tile(code, (1, Gs.components.synthesis.input_shape[1], 1))
            rendered.append(
                Gs.components.synthesis.run(
                    code, randomize_noise=False, output_transform=fmt
                )
            )
        batches.append(np.concatenate(rendered))

    stacked = np.concatenate(batches)
    np.save('./npy/' + dataset_name + '/' + save_name, stacked)
|
|
|
def GetS(dataset_name, num_img):
    """Evaluate the selected synthesis tensors for the saved codes.

    The first ``num_img`` rows of ``./npy/<dataset_name>/W.npy`` are tiled
    along a new second axis ``Gs.components.synthesis.input_shape[1]`` times
    and fed to ``'G_synthesis_1/dlatents_in:0'``. The tensors chosen by
    ``GetSNames(suffix=None)`` are evaluated in a single ``sess.run`` call.

    Returns:
        ``[layer_names, values]``: the tensor names and the matching list of
        evaluated results.
    """
    print('Generate S')
    codes = np.load('./npy/' + dataset_name + '/W.npy')[:num_img]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # The model is loaded inside the session; keep this ordering.
        Gs = LoadModel(dataset_name)
        Gs.print_layers()
        style_tensors = GetSNames(suffix=None)

        num_layers = Gs.components.synthesis.input_shape[1]
        tiled = np.tile(codes[:, None, :], (1, num_layers, 1))

        all_s = sess.run(
            style_tensors,
            feed_dict={'G_synthesis_1/dlatents_in:0': tiled},
        )

    layer_names = [tensor.name for tensor in style_tensors]
    return [layer_names, all_s]
|
|
|
|
|
|
|
|
|
def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False):
    """Convert a minibatch of images from float to uint8.

    Can be used as an output transformation for Network.run().

    Args:
        images: NumPy array of image values.
        drange: Two-element sequence ``(low, high)`` giving the dynamic range
            of the input; ``low`` maps to 0 and ``high`` to 255.
        nchw_to_nhwc: When true, transpose axes ``[0, 2, 3, 1]`` first
            (requires a 4-D array).

    Returns:
        ``uint8`` array with values rescaled, offset by 0.5, clipped to
        ``[0, 255]`` and truncated by the integer cast.
    """
    if nchw_to_nhwc:
        images = np.transpose(images, [0, 2, 3, 1])

    low, high = drange[0], drange[1]
    scale = 255 / (high - low)
    # The 0.5 offset makes the truncating uint8 cast round to nearest.
    images = images * scale + (0.5 - low * scale)

    # Rebind instead of ``out=`` so 0-d (scalar) results are handled too.
    images = np.clip(images, 0, 255)
    return images.astype('uint8')
|
|
|
|
|
def GetCodeMS(dlatents):
    """Compute the mean and standard deviation of each array along axis 0.

    Args:
        dlatents: Iterable of arrays (any iterable is accepted, not only
            indexable sequences).

    Returns:
        ``(means, stds)``: two lists, each with one entry per input array, in
        input order.
    """
    means = []
    stds = []
    for codes in dlatents:
        means.append(codes.mean(axis=0))
        stds.append(codes.std(axis=0))
    return means, stds
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: fetch the model if needed, then generate the
    # requested kind of code and store it under ./npy/<dataset_name>/.
    parser = argparse.ArgumentParser(
        description='Generate and save StyleGAN2 latent codes '
                    '(W, S, or the per-layer mean/std of S).')

    parser.add_argument('--dataset_name', type=str, default='ffhq',
                        help='name of dataset, for example, ffhq')
    parser.add_argument('--code_type', choices=['w', 's', 's_mean_std'], default='w')

    args = parser.parse_args()
    random_state = 5
    num_img = 100_000
    num_once = 1_000
    dataset_name = args.dataset_name

    model_dir = './model/'
    model_file = os.path.join(model_dir, dataset_name + '.pkl')
    if not os.path.isfile(model_file):
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/'
        name = 'stylegan2-' + dataset_name + '-config-f.pkl'
        # Argument list (no shell) so dataset_name cannot inject commands;
        # check=True stops here instead of failing later on a missing file.
        subprocess.run(['wget', url + name, '-P', model_dir], check=True)
        os.replace(os.path.join(model_dir, name), model_file)

    # Creates ./npy as well when it is missing; no error if already present.
    npy_dir = os.path.join('./npy', dataset_name)
    os.makedirs(npy_dir, exist_ok=True)

    if args.code_type == 'w':
        Gs = LoadModel(dataset_name=dataset_name)
        GetCode(Gs, random_state, num_img, num_once, dataset_name)

    elif args.code_type == 's':
        # Only the first 2000 saved codes are evaluated for the raw S dump.
        save_tmp = GetS(dataset_name, num_img=2_000)
        with open(os.path.join(npy_dir, 'S'), "wb") as fp:
            pickle.dump(save_tmp, fp)

    elif args.code_type == 's_mean_std':
        save_tmp = GetS(dataset_name, num_img=num_img)
        dlatents = save_tmp[1]
        m, std = GetCodeMS(dlatents)
        with open(os.path.join(npy_dir, 'S_mean_std'), "wb") as fp:
            pickle.dump([m, std], fp)
|
|
|
|
|
|
|
|
|
|
|
|