ucalyptus's picture
simp
2d7efb8
raw history blame
No virus
7.98 kB
import os
import os.path
import pickle
import numpy as np
import tensorflow as tf
from dnnlib import tflib
from global_directions.utils.visualizer import HtmlPageVisualizer
def Vis(bname,suffix,out,rownames=None,colnames=None):
num_images=out.shape[0]
step=out.shape[1]
if colnames is None:
colnames=[f'Step {i:02d}' for i in range(1, step + 1)]
if rownames is None:
rownames=[str(i) for i in range(num_images)]
visualizer = HtmlPageVisualizer(
num_rows=num_images, num_cols=step + 1, viz_size=256)
visualizer.set_headers(
['Name'] +colnames)
for i in range(num_images):
visualizer.set_cell(i, 0, text=rownames[i])
for i in range(num_images):
for k in range(step):
image=out[i,k,:,:,:]
visualizer.set_cell(i, 1+k, image=image)
# Save results.
visualizer.save(f'./html/'+bname+'_'+suffix+'.html')
def LoadData(img_path):
tmp=img_path+'S'
with open(tmp, "rb") as fp: #Pickling
s_names,all_s=pickle.load( fp)
dlatents=all_s
pindexs=[]
mindexs=[]
for i in range(len(s_names)):
name=s_names[i]
if not('ToRGB' in name):
mindexs.append(i)
else:
pindexs.append(i)
tmp=img_path+'S_mean_std'
with open(tmp, "rb") as fp: #Pickling
m,std=pickle.load( fp)
return dlatents,s_names,mindexs,pindexs,m,std
def LoadModel(model_path,model_name):
# Initialize TensorFlow.
tflib.init_tf()
tmp=os.path.join(model_path,model_name)
with open(tmp, 'rb') as f:
_, _, Gs = pickle.load(f)
Gs.print_layers()
return Gs
def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False):
"""Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
Can be used as an output transformation for Network.run().
"""
if nchw_to_nhwc:
images = np.transpose(images, [0, 2, 3, 1])
scale = 255 / (drange[1] - drange[0])
images = images * scale + (0.5 - drange[0] * scale)
np.clip(images, 0, 255, out=images)
images=images.astype('uint8')
return images
def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
"""Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
Can be used as an input transformation for Network.run().
"""
if nhwc_to_nchw:
images=np.rollaxis(images, 3, 1)
return images/ 255 *(drange[1] - drange[0])+ drange[0]
class Manipulator():
def __init__(self,dataset_name='ffhq'):
self.file_path='./'
self.img_path=self.file_path+'npy/'+dataset_name+'/'
self.model_path=self.file_path+'model/'
self.dataset_name=dataset_name
self.model_name=dataset_name+'.pkl'
self.alpha=[0] #manipulation strength
self.num_images=10
self.img_index=0 #which image to start
self.viz_size=256
self.manipulate_layers=None #which layer to manipulate, list
self.dlatents,self.s_names,self.mindexs,self.pindexs,self.code_mean,self.code_std=LoadData(self.img_path)
self.sess=tf.InteractiveSession()
init = tf.global_variables_initializer()
self.sess.run(init)
self.Gs=LoadModel(self.model_path,self.model_name)
self.num_layers=len(self.dlatents)
self.Vis=Vis
self.noise_constant={}
for i in range(len(self.s_names)):
tmp1=self.s_names[i].split('/')
if not 'ToRGB' in tmp1:
tmp1[-1]='random_normal:0'
size=int(tmp1[1].split('x')[0])
tmp1='/'.join(tmp1)
tmp=(1,1,size,size)
self.noise_constant[tmp1]=np.random.random(tmp)
tmp=self.Gs.components.synthesis.input_shape[1]
d={}
d['G_synthesis_1/dlatents_in:0']=np.zeros([1,tmp,512])
names=list(self.noise_constant.keys())
tmp=tflib.run(names,d)
for i in range(len(names)):
self.noise_constant[names[i]]=tmp[i]
self.fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
self.img_size=self.Gs.output_shape[-1]
def GenerateImg(self,codes):
num_images,step=codes[0].shape[:2]
out=np.zeros((num_images,step,self.img_size,self.img_size,3),dtype='uint8')
for i in range(num_images):
for k in range(step):
d={}
for m in range(len(self.s_names)):
d[self.s_names[m]]=codes[m][i,k][None,:] #need to change
d['G_synthesis_1/4x4/Const/Shape:0']=np.array([1,18, 512], dtype=np.int32)
d.update(self.noise_constant)
img=tflib.run('G_synthesis_1/images_out:0', d)
image=convert_images_to_uint8(img, nchw_to_nhwc=True)
out[i,k,:,:,:]=image[0]
return out
def MSCode(self,dlatent_tmp,boundary_tmp):
step=len(self.alpha)
dlatent_tmp1=[tmp.reshape((self.num_images,-1)) for tmp in dlatent_tmp]
dlatent_tmp2=[np.tile(tmp[:,None],(1,step,1)) for tmp in dlatent_tmp1] # (10, 7, 512)
l=np.array(self.alpha)
l=l.reshape(
[step if axis == 1 else 1 for axis in range(dlatent_tmp2[0].ndim)])
if type(self.manipulate_layers)==int:
tmp=[self.manipulate_layers]
elif type(self.manipulate_layers)==list:
tmp=self.manipulate_layers
elif self.manipulate_layers is None:
tmp=np.arange(len(boundary_tmp))
else:
raise ValueError('manipulate_layers is wrong')
for i in tmp:
dlatent_tmp2[i]+=l*boundary_tmp[i]
codes=[]
for i in range(len(dlatent_tmp2)):
tmp=list(dlatent_tmp[i].shape)
tmp.insert(1,step)
codes.append(dlatent_tmp2[i].reshape(tmp))
return codes
def EditOne(self,bname,dlatent_tmp=None):
if dlatent_tmp==None:
dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents]
boundary_tmp=[]
for i in range(len(self.boundary)):
tmp=self.boundary[i]
if len(tmp)<=bname:
boundary_tmp.append([])
else:
boundary_tmp.append(tmp[bname])
codes=self.MSCode(dlatent_tmp,boundary_tmp)
out=self.GenerateImg(codes)
return codes,out
def EditOneC(self,cindex,dlatent_tmp=None):
if dlatent_tmp==None:
dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents]
boundary_tmp=[[] for i in range(len(self.dlatents))]
#'only manipulate 1 layer and one channel'
assert len(self.manipulate_layers)==1
ml=self.manipulate_layers[0]
tmp=dlatent_tmp[ml].shape[1] #ada
tmp1=np.zeros(tmp)
tmp1[cindex]=self.code_std[ml][cindex] #1
boundary_tmp[ml]=tmp1
codes=self.MSCode(dlatent_tmp,boundary_tmp)
out=self.GenerateImg(codes)
return codes,out
def W2S(self,dlatent_tmp):
all_s = self.sess.run(
self.s_names,
feed_dict={'G_synthesis_1/dlatents_in:0': dlatent_tmp})
return all_s
#%%
if __name__ == "__main__":
M=Manipulator(dataset_name='ffhq')
#%%
M.alpha=[-5,0,5]
M.num_images=20
lindex,cindex=6,501
M.manipulate_layers=[lindex]
codes,out=M.EditOneC(cindex) #dlatent_tmp
tmp=str(M.manipulate_layers)+'_'+str(cindex)
M.Vis(tmp,'c',out)