|
|
|
|
|
import os |
|
import os.path |
|
import pickle |
|
import numpy as np |
|
import tensorflow as tf |
|
from dnnlib import tflib |
|
from global_directions.utils.visualizer import HtmlPageVisualizer |
|
|
|
|
|
def Vis(bname,suffix,out,rownames=None,colnames=None): |
|
num_images=out.shape[0] |
|
step=out.shape[1] |
|
|
|
if colnames is None: |
|
colnames=[f'Step {i:02d}' for i in range(1, step + 1)] |
|
if rownames is None: |
|
rownames=[str(i) for i in range(num_images)] |
|
|
|
|
|
visualizer = HtmlPageVisualizer( |
|
num_rows=num_images, num_cols=step + 1, viz_size=256) |
|
visualizer.set_headers( |
|
['Name'] +colnames) |
|
|
|
for i in range(num_images): |
|
visualizer.set_cell(i, 0, text=rownames[i]) |
|
|
|
for i in range(num_images): |
|
for k in range(step): |
|
image=out[i,k,:,:,:] |
|
visualizer.set_cell(i, 1+k, image=image) |
|
|
|
|
|
visualizer.save(f'./html/'+bname+'_'+suffix+'.html') |
|
|
|
|
|
|
|
|
|
def LoadData(img_path): |
|
tmp=img_path+'S' |
|
with open(tmp, "rb") as fp: |
|
s_names,all_s=pickle.load( fp) |
|
dlatents=all_s |
|
|
|
pindexs=[] |
|
mindexs=[] |
|
for i in range(len(s_names)): |
|
name=s_names[i] |
|
if not('ToRGB' in name): |
|
mindexs.append(i) |
|
else: |
|
pindexs.append(i) |
|
|
|
tmp=img_path+'S_mean_std' |
|
with open(tmp, "rb") as fp: |
|
m,std=pickle.load( fp) |
|
|
|
return dlatents,s_names,mindexs,pindexs,m,std |
|
|
|
|
|
def LoadModel(model_path,model_name): |
|
|
|
tflib.init_tf() |
|
tmp=os.path.join(model_path,model_name) |
|
with open(tmp, 'rb') as f: |
|
_, _, Gs = pickle.load(f) |
|
Gs.print_layers() |
|
return Gs |
|
|
|
def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False): |
|
"""Convert a minibatch of images from float32 to uint8 with configurable dynamic range. |
|
Can be used as an output transformation for Network.run(). |
|
""" |
|
if nchw_to_nhwc: |
|
images = np.transpose(images, [0, 2, 3, 1]) |
|
|
|
scale = 255 / (drange[1] - drange[0]) |
|
images = images * scale + (0.5 - drange[0] * scale) |
|
|
|
np.clip(images, 0, 255, out=images) |
|
images=images.astype('uint8') |
|
return images |
|
|
|
|
|
def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): |
|
"""Convert a minibatch of images from uint8 to float32 with configurable dynamic range. |
|
Can be used as an input transformation for Network.run(). |
|
""" |
|
if nhwc_to_nchw: |
|
images=np.rollaxis(images, 3, 1) |
|
return images/ 255 *(drange[1] - drange[0])+ drange[0] |
|
|
|
|
|
class Manipulator(): |
|
def __init__(self,dataset_name='ffhq'): |
|
self.file_path='./' |
|
self.img_path=self.file_path+'npy/'+dataset_name+'/' |
|
self.model_path=self.file_path+'model/' |
|
self.dataset_name=dataset_name |
|
self.model_name=dataset_name+'.pkl' |
|
|
|
self.alpha=[0] |
|
self.num_images=10 |
|
self.img_index=0 |
|
self.viz_size=256 |
|
self.manipulate_layers=None |
|
|
|
self.dlatents,self.s_names,self.mindexs,self.pindexs,self.code_mean,self.code_std=LoadData(self.img_path) |
|
|
|
self.sess=tf.InteractiveSession() |
|
init = tf.global_variables_initializer() |
|
self.sess.run(init) |
|
self.Gs=LoadModel(self.model_path,self.model_name) |
|
self.num_layers=len(self.dlatents) |
|
|
|
self.Vis=Vis |
|
self.noise_constant={} |
|
|
|
for i in range(len(self.s_names)): |
|
tmp1=self.s_names[i].split('/') |
|
if not 'ToRGB' in tmp1: |
|
tmp1[-1]='random_normal:0' |
|
size=int(tmp1[1].split('x')[0]) |
|
tmp1='/'.join(tmp1) |
|
tmp=(1,1,size,size) |
|
self.noise_constant[tmp1]=np.random.random(tmp) |
|
|
|
tmp=self.Gs.components.synthesis.input_shape[1] |
|
d={} |
|
d['G_synthesis_1/dlatents_in:0']=np.zeros([1,tmp,512]) |
|
names=list(self.noise_constant.keys()) |
|
tmp=tflib.run(names,d) |
|
for i in range(len(names)): |
|
self.noise_constant[names[i]]=tmp[i] |
|
|
|
self.fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) |
|
self.img_size=self.Gs.output_shape[-1] |
|
|
|
def GenerateImg(self,codes): |
|
|
|
|
|
num_images,step=codes[0].shape[:2] |
|
|
|
|
|
out=np.zeros((num_images,step,self.img_size,self.img_size,3),dtype='uint8') |
|
for i in range(num_images): |
|
for k in range(step): |
|
d={} |
|
for m in range(len(self.s_names)): |
|
d[self.s_names[m]]=codes[m][i,k][None,:] |
|
d['G_synthesis_1/4x4/Const/Shape:0']=np.array([1,18, 512], dtype=np.int32) |
|
d.update(self.noise_constant) |
|
img=tflib.run('G_synthesis_1/images_out:0', d) |
|
image=convert_images_to_uint8(img, nchw_to_nhwc=True) |
|
out[i,k,:,:,:]=image[0] |
|
return out |
|
|
|
|
|
|
|
def MSCode(self,dlatent_tmp,boundary_tmp): |
|
|
|
step=len(self.alpha) |
|
dlatent_tmp1=[tmp.reshape((self.num_images,-1)) for tmp in dlatent_tmp] |
|
dlatent_tmp2=[np.tile(tmp[:,None],(1,step,1)) for tmp in dlatent_tmp1] |
|
|
|
l=np.array(self.alpha) |
|
l=l.reshape( |
|
[step if axis == 1 else 1 for axis in range(dlatent_tmp2[0].ndim)]) |
|
|
|
if type(self.manipulate_layers)==int: |
|
tmp=[self.manipulate_layers] |
|
elif type(self.manipulate_layers)==list: |
|
tmp=self.manipulate_layers |
|
elif self.manipulate_layers is None: |
|
tmp=np.arange(len(boundary_tmp)) |
|
else: |
|
raise ValueError('manipulate_layers is wrong') |
|
|
|
for i in tmp: |
|
dlatent_tmp2[i]+=l*boundary_tmp[i] |
|
|
|
codes=[] |
|
for i in range(len(dlatent_tmp2)): |
|
tmp=list(dlatent_tmp[i].shape) |
|
tmp.insert(1,step) |
|
codes.append(dlatent_tmp2[i].reshape(tmp)) |
|
return codes |
|
|
|
|
|
def EditOne(self,bname,dlatent_tmp=None): |
|
if dlatent_tmp==None: |
|
dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents] |
|
|
|
boundary_tmp=[] |
|
for i in range(len(self.boundary)): |
|
tmp=self.boundary[i] |
|
if len(tmp)<=bname: |
|
boundary_tmp.append([]) |
|
else: |
|
boundary_tmp.append(tmp[bname]) |
|
|
|
codes=self.MSCode(dlatent_tmp,boundary_tmp) |
|
|
|
out=self.GenerateImg(codes) |
|
return codes,out |
|
|
|
def EditOneC(self,cindex,dlatent_tmp=None): |
|
if dlatent_tmp==None: |
|
dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents] |
|
|
|
boundary_tmp=[[] for i in range(len(self.dlatents))] |
|
|
|
|
|
assert len(self.manipulate_layers)==1 |
|
|
|
ml=self.manipulate_layers[0] |
|
tmp=dlatent_tmp[ml].shape[1] |
|
tmp1=np.zeros(tmp) |
|
tmp1[cindex]=self.code_std[ml][cindex] |
|
boundary_tmp[ml]=tmp1 |
|
|
|
codes=self.MSCode(dlatent_tmp,boundary_tmp) |
|
out=self.GenerateImg(codes) |
|
return codes,out |
|
|
|
|
|
def W2S(self,dlatent_tmp): |
|
|
|
all_s = self.sess.run( |
|
self.s_names, |
|
feed_dict={'G_synthesis_1/dlatents_in:0': dlatent_tmp}) |
|
return all_s |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
M=Manipulator(dataset_name='ffhq') |
|
|
|
|
|
|
|
M.alpha=[-5,0,5] |
|
M.num_images=20 |
|
lindex,cindex=6,501 |
|
|
|
M.manipulate_layers=[lindex] |
|
codes,out=M.EditOneC(cindex) |
|
tmp=str(M.manipulate_layers)+'_'+str(cindex) |
|
M.Vis(tmp,'c',out) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|