# NOTE: residue from a web file viewer, kept as a comment so the module parses:
# ucalyptus's picture | simp | 2d7efb8 | raw | history blame | No virus | 2.62 kB
import numpy as np
import torch
import clip
from PIL import Image
import copy
from manipulate import Manipulator
import argparse
def GetImgF(out, model, preprocess, device=None):
    """Encode a 2-D grid of images with a CLIP image encoder.

    Args:
        out: array whose first two axes index a grid of images and whose
            remaining axes are a single image accepted by
            ``PIL.Image.fromarray`` (presumably uint8 HxWx3 — TODO confirm
            against ``Manipulator.EditOneC``).
        model: CLIP model exposing ``encode_image``.
        preprocess: CLIP preprocessing transform (PIL image -> CHW tensor).
        device: device to run the encoder on. Defaults to the device holding
            the model's parameters (the original version read a module-level
            ``device`` that only exists when this file is run as a script).

    Returns:
        NumPy array of shape ``out.shape[:2] + (feature_dim,)`` where
        ``feature_dim`` is the encoder's output width.
    """
    imgs = out
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else "cpu"

    grid_shape = list(imgs.shape[:2])
    flat = imgs.reshape([-1] + list(imgs.shape[2:]))
    # Preprocess on the CPU and transfer the whole batch in one copy instead
    # of one host-to-device transfer per image.
    batch = torch.stack([preprocess(Image.fromarray(img)) for img in flat])
    batch = batch.to(device)
    with torch.no_grad():
        feats = model.encode_image(batch)
    feats = feats.cpu().numpy()
    # Use the encoder's real width rather than hard-coding 512 (ViT-B/32).
    return feats.reshape(grid_shape + [feats.shape[-1]])
def _unit(v):
    """Scale vectors along the last axis to unit L2 norm; zero vectors stay zero."""
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    out = np.zeros(v.shape, dtype=np.result_type(v, norm))
    # Guarded division: a zero-length vector would otherwise become NaN and
    # poison every later mean/normalisation it takes part in.
    return np.divide(v, norm, out=out, where=norm > 0)


def GetFs(fs):
    """Reduce per-image CLIP features to one unit direction per channel.

    Args:
        fs: rank-4 array indexed ``[channel, i, alpha, feature]``. Axis 2
            holds the pair of edit strengths (index 0 = -5*sigma,
            index 1 = +5*sigma); axis 1 is averaged over (presumably the
            sampled images — TODO confirm).

    Returns:
        Array of shape ``(channel, feature)`` whose rows have unit L2 norm,
        or are all-zero when a channel's edit leaves the features unchanged.
    """
    fs = np.asarray(fs)
    unit_feats = _unit(fs)
    # Normalised direction from the -alpha image to the +alpha image.
    deltas = _unit(unit_feats[:, :, 1, :] - unit_feats[:, :, 0, :])
    return _unit(deltas.mean(axis=1))
#%%
if __name__ == "__main__":
    import os  # only the script entry point needs it

    parser = argparse.ArgumentParser(
        description='Compute per-channel CLIP feature directions (fs3) '
                    'for a StyleGAN dataset.')
    parser.add_argument('--dataset_name', type=str, default='cat',
                        help='name of dataset, for example, ffhq')
    args = parser.parse_args()
    dataset_name = args.dataset_name

    #%%
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    #%%
    M = Manipulator(dataset_name=dataset_name)
    np.set_printoptions(suppress=True)
    print(M.dataset_name)

    #%%
    # From every layer, take the batch of `num_images` latents that starts
    # at index img_sindex * num_images.
    img_sindex = 0
    num_images = 100
    start = img_sindex * num_images
    dlatents_o = [layer[start:start + num_images] for layer in M.dlatents]

    #%%
    all_f = []
    M.alpha = [-5, 5]  # ffhq 5
    M.step = 2
    M.num_images = num_images
    mindexs = np.array(M.mindexs)
    mindexs2 = mindexs[mindexs <= 16]  # below or equal to 128 resolution
    for lindex in mindexs2:  # ignore ToRGB layers
        print(lindex)
        num_c = dlatents_o[lindex].shape[1]
        for cindex in range(num_c):
            # Deep copy: the next line writes into the layer array in place.
            # With a shallow copy.copy() that write landed in the arrays
            # shared with dlatents_o (and, for NumPy slices, in the source
            # latents), so every earlier channel stayed pinned to its mean.
            M.dlatents = copy.deepcopy(dlatents_o)
            M.dlatents[lindex][:, cindex] = M.code_mean[lindex][cindex]
            M.manipulate_layers = [lindex]
            _, out = M.EditOneC(cindex)
            all_f.append(GetImgF(out, model, preprocess))
    all_f = np.array(all_f)
    fs3 = GetFs(all_f)

    #%%
    file_path = './npy/' + M.dataset_name + '/'
    # Avoid losing the whole run to a missing output directory.
    os.makedirs(file_path, exist_ok=True)
    np.save(file_path + 'fs3', fs3)