# Provenance (Hugging Face file-view header captured with this file):
#   commit 4d9fdb5 by trysem: "Duplicate from ucalyptus/PTI"
#   size 7.1 kB; hub malware scan: "No virus"
import os
import pickle
import numpy as np
from dnnlib import tflib
import tensorflow as tf
import argparse
def LoadModel(dataset_name):
    """Load and return the generator ``Gs`` for *dataset_name*.

    Calls ``tflib.init_tf()`` and then unpickles ``./model/<dataset_name>.pkl``.
    The pickle must hold a 3-tuple; its third element is returned.

    NOTE(review): unpickling executes arbitrary code from the file. Only load
    snapshots from trusted sources.
    """
    tflib.init_tf()  # Initialize TensorFlow.
    pkl_path = os.path.join('./model/', dataset_name + '.pkl')
    with open(pkl_path, 'rb') as pkl_file:
        snapshot = pickle.load(pkl_file)
    _, _, Gs = snapshot
    return Gs
def lerp(a, b, t):
    """Linearly interpolate from *a* (at ``t == 0``) toward *b* (at ``t == 1``).

    Works element-wise on NumPy arrays via broadcasting.
    """
    delta = b - a
    return a + delta * t
#stylegan-ada
def SelectName(layer_name, suffix):
    """Return True if *layer_name* looks like a wanted tensor in ``G_synthesis_1``.

    ``layer_name`` is the ``str()`` of an op's output-tensor tuple (as built
    in ``GetSNames``), so every check is a plain substring test on that text.

    Args:
        layer_name: Text form of a TF op's outputs.
        suffix: ``None`` selects ``add:0`` outputs whose printed shape starts
            with an unknown dimension (``shape=(?,``). Otherwise it is a
            tensor-name suffix (e.g. ``'/mul_1:0'``) appended to the per-layer
            scopes ``/Conv0_up``, ``/Conv1``, ``4x4/Conv`` and ``/ToRGB``.

    Returns:
        bool
    """
    if 'G_synthesis_1' not in layer_name:
        return False
    if suffix is None:
        return 'add:0' in layer_name and 'shape=(?,' in layer_name
    scopes = ('/Conv0_up', '/Conv1', '4x4/Conv', '/ToRGB')
    return any(scope + suffix in layer_name for scope in scopes)
def GetSNames(suffix):
    """Collect style tensors from the current default TF graph.

    Walks every operation in the default graph, formats its output tensors
    with ``str()``, and keeps the first output of each op whose text passes
    ``SelectName(..., suffix)``.

    Args:
        suffix: Forwarded to ``SelectName``; ``None`` selects the ``add:0``
            style tensors.

    Returns:
        list of ``tf.Tensor`` in graph-operation order.
    """
    # Read the default graph directly. Opening a tf.Session() just to access
    # sess.graph yields this same graph but creates (and tears down) an extra
    # session, and possibly its device allocations, for nothing.
    graph = tf.get_default_graph()
    selected = []
    for op in graph.get_operations():
        outputs = op.values()
        if SelectName(str(outputs), suffix):
            selected.append(outputs[0])
    return selected
def SelectName2(layer_name):
    """Return True for names containing ``mod_bias`` or ``mod_weight``, excluding ToRGB ones."""
    if 'ToRGB' in layer_name:
        return False
    return 'mod_bias' in layer_name or 'mod_weight' in layer_name
def GetKName(Gs):
    """Return the synthesis-network variables selected by ``SelectName2``.

    Iterates ``Gs.components.synthesis.vars`` in its own order and keeps each
    variable whose ``str()`` mentions ``mod_bias`` or ``mod_weight`` outside a
    ToRGB layer (presumably the style-modulation parameters; verify).
    """
    # Only the variables are needed, so iterate values() rather than items().
    synthesis_vars = Gs.components.synthesis.vars.values()
    return [var for var in synthesis_vars if SelectName2(str(var))]
def GetCode(Gs, random_state, num_img, num_once, dataset_name,
            truncation_psi=0.7, truncation_cutoff=8):
    """Sample W codes with the mapping network and save them as ``./npy/<dataset_name>/W.npy``.

    Z latents are drawn from ``np.random.RandomState(random_state)`` in
    batches of at most ``num_once`` rows and mapped with
    ``Gs.components.mapping``. Unless either truncation argument is ``None``,
    the mapped codes are interpolated toward ``Gs``'s ``dlatent_avg``
    (truncation trick). Only layer 0 of each mapped code is stored.

    Args:
        Gs: Generator returned by ``LoadModel``.
        random_state: Seed for the latent sampler.
        num_img: Number of codes to generate. It need not be a multiple of
            ``num_once``; the final batch is simply smaller.
        num_once: Maximum batch size per mapping-network call.
        dataset_name: Sub-directory of ``./npy/`` to write into. It must
            already exist (``np.save`` does not create directories).
        truncation_psi: Weight ``t`` in ``lerp(dlatent_avg, w, t)`` for
            truncated layers.
        truncation_cutoff: Layers with index ``< truncation_cutoff`` are
            truncated; the others keep weight 1.

    Note:
        ``np.save`` appends the ``.npy`` extension to the ``.../W`` path.
    """
    rnd = np.random.RandomState(random_state)  # 5
    latent_dim = Gs.input_shape[1]
    dlatent_avg = Gs.get_var('dlatent_avg')
    # Width comes from dlatent_avg (which must broadcast against the mapped
    # codes in lerp) instead of a hard-coded 512.
    dlatents = np.zeros((num_img, dlatent_avg.shape[-1]), dtype='float32')
    # Step in num_once chunks with a possibly shorter last chunk, so no
    # trailing rows stay zero when num_img is not a multiple of num_once.
    for start in range(0, num_img, num_once):
        count = min(num_once, num_img - start)
        src_latents = rnd.randn(count, latent_dim)
        src_dlatents = Gs.components.mapping.run(src_latents, None)  # [seed, layer, component]
        if truncation_psi is not None and truncation_cutoff is not None:
            # Apply truncation trick.
            layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            src_dlatents = lerp(dlatent_avg, src_dlatents, coefs)
        dlatents[start:start + count, :] = src_dlatents[:, 0, :].astype('float32')
    print('get all z and w')
    np.save('./npy/' + dataset_name + '/W', dlatents)
def GetImg(Gs, num_img, num_once, dataset_name, save_name='images'):
    """Render images from saved W codes and save them as ``./npy/<dataset_name>/<save_name>.npy``.

    Loads ``./npy/<dataset_name>/W.npy`` and tiles each code across
    ``Gs.components.synthesis.input_shape[1]`` inputs. It then renders the
    first ``num_img`` codes one at a time (``randomize_noise=False``,
    uint8 NHWC output) and prints the batch index as progress.

    Args:
        Gs: Generator returned by ``LoadModel``.
        num_img: Number of codes to render. It need not be a multiple of
            ``num_once``.
        num_once: Number of images grouped per progress step.
        dataset_name: Sub-directory of ``./npy/`` to read from and write to.
        save_name: Output file stem; ``np.save`` appends ``.npy``.

    Raises:
        IndexError: if ``num_img`` exceeds the number of saved codes.
    """
    print('Generate Image')
    dlatents = np.load('./npy/' + dataset_name + '/W.npy')
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    synthesis = Gs.components.synthesis
    num_ws = synthesis.input_shape[1]  # loop-invariant, hoisted
    all_images = []
    # Last step may be shorter, so the trailing num_img % num_once codes are
    # rendered instead of silently skipped.
    for batch_idx, start in enumerate(range(0, num_img, num_once)):
        print(batch_idx)
        stop = min(start + num_once, num_img)
        images = []
        # Rendered one at a time (not as a batch), as in the original flow;
        # presumably to bound device memory. TODO confirm.
        for idx in range(start, stop):
            ws = np.tile(dlatents[idx][None, None, :], (1, num_ws, 1))
            images.append(synthesis.run(ws, randomize_noise=False, output_transform=fmt))
        all_images.append(np.concatenate(images))
    all_images = np.concatenate(all_images)
    np.save('./npy/' + dataset_name + '/' + save_name, all_images)
def GetS(dataset_name, num_img):
    """Evaluate the selected style tensors for the first ``num_img`` saved W codes.

    Loads ``./npy/<dataset_name>/W.npy``, opens a ``tf.Session``, and loads
    the generator inside it. Each W code is tiled across
    ``Gs.components.synthesis.input_shape[1]`` inputs. The tensors chosen by
    ``GetSNames(suffix=None)`` are then run with those codes fed to
    ``G_synthesis_1/dlatents_in:0``.

    Returns:
        ``[names, values]``: the selected tensors' names and the list of
        arrays ``sess.run`` produced for them, in matching order.
    """
    print('Generate S')
    w_codes = np.load('./npy/' + dataset_name + '/W.npy')[:num_img]
    # Keep this order: LoadModel runs while this session is the default one.
    # Presumably this is so tflib reuses it for the unpickled variables (verify).
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        Gs = LoadModel(dataset_name)
        Gs.print_layers()  # for ada
        style_tensors = GetSNames(suffix=None)  # None,'/mul_1:0','/mod_weight/read:0','/MatMul:0'
        num_ws = Gs.components.synthesis.input_shape[1]
        ws = np.tile(w_codes[:, None, :], (1, num_ws, 1))
        style_values = sess.run(
            style_tensors,
            feed_dict={'G_synthesis_1/dlatents_in:0': ws})
    tensor_names = [tensor.name for tensor in style_tensors]
    return [tensor_names, style_values]
def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False):
    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.

    Can be used as an output transformation for Network.run().

    Args:
        images: NumPy array of images. It must be 4-D with channels on axis 1
            when ``nchw_to_nhwc`` is True.
        drange: ``(low, high)`` input values mapped to 0 and 255. Any indexable
            pair works; a tuple default avoids a shared mutable default.
        nchw_to_nhwc: If True, move axis 1 to the last position first.

    Returns:
        ``uint8`` array with values rounded to nearest and clipped to [0, 255].
    """
    if nchw_to_nhwc:
        images = np.transpose(images, [0, 2, 3, 1])
    low, high = drange[0], drange[1]
    scale = 255 / (high - low)
    # The +0.5 makes the truncating astype('uint8') below round to nearest.
    images = images * scale + (0.5 - low * scale)
    np.clip(images, 0, 255, out=images)
    return images.astype('uint8')
def GetCodeMS(dlatents):
    """Compute the mean and standard deviation of each array along axis 0.

    Args:
        dlatents: Iterable of NumPy arrays. The ``s_mean_std`` script branch
            passes the list of arrays produced by ``GetS``.

    Returns:
        ``(m, std)``: two lists aligned with ``dlatents`` holding
        ``x.mean(axis=0)`` and ``x.std(axis=0)`` for each entry.
    """
    m = []
    std = []
    # Iterate directly (single pass) instead of indexing via range(len(...)),
    # so any iterable, including a generator, is accepted.
    for layer in dlatents:
        m.append(layer.mean(axis=0))
        std.append(layer.std(axis=0))
    return m, std
#%%
def _ensure_model(dataset_name):
    """Download ``stylegan2-<dataset_name>-config-f.pkl`` to ``./model/<dataset_name>.pkl`` if missing.

    The file is fetched from ``https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/``
    into a ``.part`` file and then renamed with ``os.replace``. An interrupted
    download therefore never sits at the final path.

    Raises:
        urllib.error.URLError: if the download fails (including HTTP errors).
    """
    import urllib.request  # local import: only needed when a download happens

    model_path = os.path.join('./model/', dataset_name + '.pkl')
    if os.path.isfile(model_path):
        return
    os.makedirs('./model/', exist_ok=True)
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/'
    name = 'stylegan2-' + dataset_name + '-config-f.pkl'
    part_path = os.path.join('./model/', name + '.part')
    # No shell is involved, so a crafted --dataset_name cannot inject
    # commands, and failures raise instead of being ignored. Writing to a
    # fixed .part path also avoids wget's "name.1" renaming when a stale
    # partial file exists, which would make `mv` install the truncated copy.
    urllib.request.urlretrieve(url + name, part_path)
    os.replace(part_path, model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Generate W or S latent codes for a StyleGAN2 generator.')
    parser.add_argument('--dataset_name', type=str, default='ffhq',
                        help='name of dataset, for example, ffhq')
    parser.add_argument('--code_type', choices=['w', 's', 's_mean_std'], default='w')
    args = parser.parse_args()

    random_state = 5
    num_img = 100_000
    num_once = 1_000
    dataset_name = args.dataset_name

    _ensure_model(dataset_name)
    # makedirs also creates the parent ./npy/ directory when it is missing.
    out_dir = os.path.join('./npy/', dataset_name)
    os.makedirs(out_dir, exist_ok=True)

    if args.code_type == 'w':
        Gs = LoadModel(dataset_name=dataset_name)
        GetCode(Gs, random_state, num_img, num_once, dataset_name)
        # GetImg(Gs, num_img=num_img, num_once=num_once, dataset_name=dataset_name, save_name='images_100K')  # no need
    elif args.code_type == 's':
        save_tmp = GetS(dataset_name, num_img=2_000)
        with open(os.path.join(out_dir, 'S'), "wb") as fp:
            pickle.dump(save_tmp, fp)
    elif args.code_type == 's_mean_std':
        save_tmp = GetS(dataset_name, num_img=num_img)
        m, std = GetCodeMS(save_tmp[1])
        with open(os.path.join(out_dir, 'S_mean_std'), "wb") as fp:
            pickle.dump([m, std], fp)