# Repository-page header text (not code), preserved as comments so the file parses:
#   MarkoVidrih's picture
#   Duplicate from DragGan/DragGan-Inversion
#   42d4082
from manipulate import Manipulator
import tensorflow as tf
import numpy as np
import torch
import clip
from MapTS import GetBoundary,GetDt
class StyleCLIP():
    """Couple a CLIP text model with a ``Manipulator`` to edit one image.

    Construction loads CLIP ("ViT-B/32"), builds a ``Manipulator`` for the
    requested dataset, loads two ``.npy`` files from disk and initialises the
    edit parameters. Afterwards callers typically set ``self.target`` and
    ``self.neutral`` (text prompts), call :meth:`GetDt2`, and then
    :meth:`GetImg` to obtain the manipulated image.

    Attributes set by this class:
        model: the CLIP model returned by ``clip.load``.
        M: the ``Manipulator`` instance.
        fs3: array loaded from ``./npy/<dataset_name>/fs3.npy``.
        c_threshold: 20 for ``'ffhq'``, 100 for any other dataset.
        target, neutral: prompt strings, both initialised to ``''``.
        dt: result of ``GetDt([target, neutral], model)``.
        beta: threshold passed to ``GetBoundary`` when generating codes.
    """

    def __init__(self, dataset_name='ffhq'):
        """Load CLIP, then load all data for *dataset_name*."""
        print('load clip')
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # clip.load returns (model, preprocess); the preprocess transform is
        # not needed here, so it is discarded.
        self.model, _ = clip.load("ViT-B/32", device=device)
        self.LoadData(dataset_name)

    def LoadData(self, dataset_name):
        """Build the manipulator and load the per-dataset arrays.

        Reads ``./npy/<dataset_name>/fs3.npy`` and
        ``./data/<dataset_name>/w_plus.npy`` relative to the current working
        directory, then finishes with :meth:`SetInitP`.
        """
        tf.keras.backend.clear_session()
        manipulator = Manipulator(dataset_name=dataset_name)
        # NOTE: this changes NumPy's global print options for the process.
        np.set_printoptions(suppress=True)
        fs3 = np.load('./npy/' + dataset_name + '/fs3.npy')

        self.M = manipulator
        self.fs3 = fs3

        w_plus = np.load('./data/' + dataset_name + '/w_plus.npy')
        # presumably converts W+ latents into the manipulator's style codes
        # -- TODO confirm against Manipulator.W2S.
        self.M.dlatents = manipulator.W2S(w_plus)

        self.c_threshold = 20 if dataset_name == 'ffhq' else 100
        self.SetInitP()

    def SetInitP(self):
        """Reset edit parameters to their defaults and select image 0."""
        self.M.alpha = [3]
        self.M.num_images = 1

        self.target = ''
        self.neutral = ''
        # Must run after target/neutral are set: GetDt2 reads both.
        self.GetDt2()

        img_index = 0
        # Keep a length-1 slice of every entry of dlatents so each entry
        # retains its leading axis.
        self.M.dlatent_tmp = [
            tmp[img_index:(img_index + 1)] for tmp in self.M.dlatents
        ]

    @staticmethod
    def _select_beta(betas, num_cs, c_threshold, default=0.1):
        """Pick the last beta whose count exceeds *c_threshold*.

        Args:
            betas: candidate thresholds, in the order they were evaluated.
            num_cs: one count per entry of *betas*.
            c_threshold: a beta qualifies when its count is strictly greater.
            default: value returned when no beta qualifies.

        Returns:
            The last qualifying element of *betas*, or *default*.
        """
        betas = np.asarray(betas)
        qualifies = np.asarray(num_cs) > c_threshold
        if not qualifies.any():
            return default
        return betas[qualifies][-1]

    def GetDt2(self):
        """Recompute ``self.dt`` from the prompts and choose ``self.beta``.

        ``GetBoundary`` is evaluated for every threshold in
        ``np.arange(0.1, 0.3, 0.01)``; the second value it returns is compared
        against ``self.c_threshold`` to select the beta (see
        :meth:`_select_beta`). Each threshold is printed as it is evaluated.
        """
        classnames = [self.target, self.neutral]
        self.dt = GetDt(classnames, self.model)

        betas = np.arange(0.1, 0.3, 0.01)
        num_cs = []
        for beta in betas:
            # Only the count is needed while scanning; the boundary itself is
            # recomputed in GetCode for the chosen beta.
            _, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=beta)
            print(beta)
            num_cs.append(num_c)

        self.beta = self._select_beta(betas, num_cs, self.c_threshold)

    def GetCode(self):
        """Return the codes from ``M.MSCode`` for the current ``dt``/``beta``."""
        boundary, _ = GetBoundary(
            self.fs3, self.dt, self.M, threshold=self.beta
        )
        return self.M.MSCode(self.M.dlatent_tmp, boundary)

    def GetImg(self):
        """Generate from the current codes and return element ``[0, 0]``."""
        codes = self.GetCode()
        out = self.M.GenerateImg(codes)
        # presumably indexes (image, alpha) -- TODO confirm against
        # Manipulator.GenerateImg.
        return out[0, 0]
#%%
if __name__ == "__main__":
    # Script/cell entry point: build the default ('ffhq') StyleCLIP instance.
    # The object is bound to both `self` and `style_clip` so that method
    # bodies can be pasted and run line by line in an interactive session.
    self = style_clip = StyleCLIP()