|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
import clip |
|
from PIL import Image |
|
import copy |
|
from manipulate import Manipulator |
|
import argparse |
|
|
|
def GetImgF(out, model, preprocess, device=None):
    """Encode every generated image in ``out`` with CLIP's image encoder.

    The first two axes of ``out`` are flattened into one batch. Each remaining
    slice is converted with ``Image.fromarray`` and passed through
    ``preprocess``. The features are then reshaped back so that they are
    indexed by the same two leading axes.

    Args:
        out: Array of images. Presumably (num_images, num_variants, H, W, 3)
            uint8 as returned by ``Manipulator.EditOneC`` -- TODO confirm.
        model: CLIP model exposing ``encode_image`` (e.g. from ``clip.load``).
        preprocess: Transform mapping a PIL image to an image tensor
            (the preprocessing function returned by ``clip.load``).
        device: Device the batch is moved to before encoding. Defaults to the
            device of the model's parameters, so the function no longer
            depends on a module-level ``device`` global.

    Returns:
        NumPy array of shape ``out.shape[:2] + (feature_dim,)``, where
        ``feature_dim`` is the encoder's output width (512 for ViT-B/32).
    """
    lead_shape = tuple(out.shape[:2])
    flat = out.reshape((-1,) + tuple(out.shape[2:]))

    if device is None:
        # Run wherever the model already lives instead of relying on a global.
        device = next(model.parameters()).device

    # Preprocess on the CPU and transfer the whole batch once.
    batch = torch.stack(
        [preprocess(Image.fromarray(frame)) for frame in flat]
    ).to(device)

    with torch.no_grad():
        features = model.encode_image(batch)

    # Infer the feature width (-1) instead of hard-coding 512, so other CLIP
    # backbones with wider embeddings also work.
    return features.cpu().numpy().reshape(lead_shape + (-1,))
|
|
|
def _unit_norm(x, axis=-1):
    """Scale ``x`` to unit L2 norm along ``axis``; all-zero slices stay zero."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    # Dividing a zero slice by 1 keeps it at zero instead of producing 0/0 = NaN.
    return x / np.where(norm == 0, 1, norm)


def GetFs(fs):
    """Turn paired CLIP image features into one unit direction per channel.

    Args:
        fs: 4-D array indexed as ``[channel, image, variant, feature]``. Only
            variants 0 and 1 are used, and the direction is variant 1 minus
            variant 0. (Presumably the two manipulation strengths -- confirm
            against the caller.)

    Returns:
        Array of shape ``(num_channels, num_features)``. Each row is the
        normalized mean over images of the normalized per-image feature
        difference. A channel whose variants are identical for every image
        gets an all-zero row rather than NaN.
    """
    fs = np.asarray(fs)
    # Normalize each feature vector so that only its direction matters.
    fs1 = _unit_norm(fs, axis=-1)
    # Per-image direction of change between the two variants.
    fs2 = _unit_norm(fs1[:, :, 1, :] - fs1[:, :, 0, :], axis=-1)
    # Average over images, then renormalize to a unit direction.
    fs3 = fs2.mean(axis=1)
    return _unit_norm(fs3, axis=-1)
|
|
|
|
|
if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser(
        description="Compute a normalized CLIP image-feature direction for "
        "each StyleSpace channel and save it as fs3.npy."
    )
    parser.add_argument('--dataset_name', type=str, default='cat',
                        help='name of dataset, for example, ffhq')
    parser.add_argument('--img_sindex', type=int, default=0,
                        help='index of the block of num_images latents to use')
    parser.add_argument('--num_images', type=int, default=100,
                        help='number of latents used per channel')
    args = parser.parse_args()
    dataset_name = args.dataset_name

    # Kept as a module-level global: GetImgF may read `device` from here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    M = Manipulator(dataset_name=dataset_name)
    np.set_printoptions(suppress=True)
    print(M.dataset_name)

    img_sindex = args.img_sindex
    num_images = args.num_images
    start = img_sindex * num_images
    # Per-layer slice of the selected latents. Slicing yields views into
    # M.dlatents, so these arrays must not be modified in place below.
    dlatents_o = [layer[start:start + num_images] for layer in M.dlatents]

    all_f = []
    # Two manipulation strengths per channel; GetFs takes variant 1 - variant 0.
    M.alpha = [-5, 5]
    M.step = 2
    M.num_images = num_images

    # NOTE(review): only layer indices <= 16 from M.mindexs are processed; the
    # reason for this cutoff is not visible here -- confirm.
    select = np.array(M.mindexs) <= 16
    mindexs2 = np.array(M.mindexs)[select]

    for lindex in mindexs2:
        print(lindex)
        num_c = M.dlatents[lindex].shape[1]
        for cindex in range(num_c):
            # Fresh latents for every channel. A shallow list copy alone would
            # share the layer arrays, so the mean written below would leak into
            # dlatents_o and pin every previously processed channel.
            M.dlatents = copy.copy(dlatents_o)
            M.dlatents[lindex] = copy.deepcopy(dlatents_o[lindex])
            M.dlatents[lindex][:, cindex] = M.code_mean[lindex][cindex]

            M.manipulate_layers = [lindex]
            _codes, out = M.EditOneC(cindex)
            all_f.append(GetImgF(out, model, preprocess))

    all_f = np.array(all_f)
    fs3 = GetFs(all_f)

    file_path = './npy/' + M.dataset_name + '/'
    os.makedirs(file_path, exist_ok=True)
    np.save(file_path + 'fs3', fs3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|