|
|
|
|
|
from manipulate import Manipulator |
|
import tensorflow as tf |
|
import numpy as np |
|
import torch |
|
import clip |
|
from MapTS import GetBoundary,GetDt |
|
|
|
class StyleCLIP():
    """Drive a ``Manipulator`` with a direction derived from CLIP text prompts.

    The prompt pair (``self.target``, ``self.neutral``) is turned into a
    direction ``self.dt`` by ``GetDt`` using a CLIP ViT-B/32 model.
    ``GetBoundary`` combines that direction with ``self.fs3`` into a boundary,
    which ``self.M`` applies to the current latent (``self.M.dlatent_tmp``).
    """

    # Candidate thresholds swept by GetDt2: np.arange(0.1, 0.3, 0.01),
    # i.e. 0.10, 0.11, ..., 0.29.
    _BETA_START = 0.1
    _BETA_STOP = 0.3
    _BETA_STEP = 0.01

    def __init__(self, dataset_name='ffhq'):
        """Load CLIP, then the Manipulator and arrays for ``dataset_name``.

        Args:
            dataset_name: sub-directory name used under ``./npy`` and
                ``./data``; also passed to ``Manipulator``.
        """
        print('load clip')
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # The preprocess transform returned by clip.load is not used by this
        # class, so it is discarded.
        self.model, _ = clip.load("ViT-B/32", device=self.device)
        self.LoadData(dataset_name)

    def LoadData(self, dataset_name):
        """(Re)build the Manipulator and load the per-dataset arrays.

        Reads ``./npy/<dataset_name>/fs3.npy`` into ``self.fs3``. Reads
        ``./data/<dataset_name>/w_plus.npy`` and converts it with
        ``Manipulator.W2S`` into ``self.M.dlatents``. Then sets
        ``self.c_threshold`` and resets the prompts via ``SetInitP``.
        """
        # Reset Keras global state before constructing a new Manipulator.
        tf.keras.backend.clear_session()
        M = Manipulator(dataset_name=dataset_name)
        np.set_printoptions(suppress=True)
        fs3 = np.load(f'./npy/{dataset_name}/fs3.npy')

        self.M = M
        self.fs3 = fs3

        w_plus = np.load(f'./data/{dataset_name}/w_plus.npy')
        self.M.dlatents = M.W2S(w_plus)

        # GetDt2 picks a beta only if its num_c exceeds this bar;
        # ffhq uses a lower bar than every other dataset.
        self.c_threshold = 20 if dataset_name == 'ffhq' else 100
        self.SetInitP()

    def SetInitP(self):
        """Reset manipulation parameters and prompts, and pick the first latent.

        Sets ``M.alpha = [3]`` and ``M.num_images = 1``, and clears both
        prompts. Recomputes ``dt``/``beta`` via ``GetDt2``. Finally stores a
        one-image slice of every entry of ``M.dlatents`` in ``M.dlatent_tmp``.
        """
        self.M.alpha = [3]
        self.M.num_images = 1

        self.target = ''
        self.neutral = ''
        self.GetDt2()

        img_index = 0
        # Slice rather than index so each entry keeps its leading axis
        # (assumes entries are arrays indexed by image along axis 0 — confirm).
        self.M.dlatent_tmp = [
            tmp[img_index:(img_index + 1)] for tmp in self.M.dlatents
        ]

    def GetDt2(self):
        """Compute the CLIP direction for the current prompts and choose beta.

        Sets ``self.dt = GetDt([target, neutral], model)``. Then, for each
        candidate threshold, calls ``GetBoundary`` and prints the candidate.
        Finally sets ``self.beta`` to the last candidate whose ``num_c``
        exceeds ``self.c_threshold`` (see ``_select_beta``).
        """
        classnames = [self.target, self.neutral]
        self.dt = GetDt(classnames, self.model)

        betas = np.arange(self._BETA_START, self._BETA_STOP, self._BETA_STEP)
        num_cs = []
        for beta in betas:
            _, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=beta)
            print(beta)
            num_cs.append(num_c)

        self.beta = self._select_beta(betas, np.array(num_cs), self.c_threshold)

    @staticmethod
    def _select_beta(betas, num_cs, c_threshold, fallback=0.1):
        """Return the last beta whose count is strictly above ``c_threshold``.

        Args:
            betas: sequence of candidate thresholds.
            num_cs: sequence of the same length; ``num_cs[i]`` is the count
                reported for ``betas[i]``.
            c_threshold: a candidate is eligible when its count is greater
                than this value.
            fallback: value returned when no candidate is eligible (including
                when ``betas`` is empty).

        Returns:
            The last eligible element of ``betas``, or ``fallback``.
        """
        eligible = np.asarray(num_cs) > c_threshold
        if not eligible.any():
            return fallback
        return np.asarray(betas)[eligible][-1]

    def GetCode(self):
        """Build codes with ``M.MSCode`` from the current latent and boundary.

        The boundary is recomputed by ``GetBoundary`` at ``self.beta``.
        """
        boundary_tmp2, _ = GetBoundary(
            self.fs3, self.dt, self.M, threshold=self.beta)
        return self.M.MSCode(self.M.dlatent_tmp, boundary_tmp2)

    def GetImg(self):
        """Generate from ``GetCode()`` and return element ``[0, 0]`` of the output."""
        codes = self.GetCode()
        out = self.M.GenerateImg(codes)
        # NOTE(review): [0, 0] presumably selects the first alpha step and the
        # first image of GenerateImg's output — confirm against Manipulator.
        return out[0, 0]
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the default-dataset instance. The module-level alias ``self``
    # lets method bodies be pasted and run directly in an interactive
    # session (presumably why it exists — confirm before removing).
    self = style_clip = StyleCLIP()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|