File size: 2,623 Bytes
4d9fdb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109



import numpy as np 
import torch
import clip
from PIL import Image
import copy
from manipulate import Manipulator
import argparse

def GetImgF(out,model,preprocess,device=None):
    """Encode a 2-D grid of images with a CLIP image encoder.

    The first two axes of ``out`` are flattened into one batch for encoding
    and restored on the returned features.

    Parameters
    ----------
    out : numpy.ndarray
        Image grid whose first two axes index the images. Each ``out[i, j]``
        is passed to ``PIL.Image.fromarray``. The layout is assumed to be
        ``(n_a, n_b, H, W, C)`` with uint8 pixels -- TODO confirm against
        ``Manipulator.EditOneC``.
    model : CLIP model
        Object exposing ``encode_image`` (e.g. from ``clip.load``). When
        ``device`` is omitted it must also expose ``parameters()`` (it is
        assumed to be a ``torch.nn.Module``).
    preprocess : callable
        Transform mapping a PIL image to a ``(3, h, w)`` tensor, as returned
        by ``clip.load``.
    device : str or torch.device, optional
        Device on which to run the encoder. Defaults to the device of the
        model's first parameter, so the function no longer depends on a
        module-level ``device`` that only exists when this file runs as a
        script.

    Returns
    -------
    numpy.ndarray
        Features of shape ``(n_a, n_b, D)``, where ``D`` is the model's
        embedding width (512 for ViT-B/32).
    """
    if device is None:
        device = next(model.parameters()).device

    lead_shape = list(out.shape[:2])
    flat = out.reshape([-1] + list(out.shape[2:]))

    # Preprocess every image and stack them into a single (N, 3, h, w) batch.
    batch = torch.stack([preprocess(Image.fromarray(img)) for img in flat]).to(device)
    with torch.no_grad():
        image_features = model.encode_image(batch)

    features = image_features.cpu().numpy()
    # Use the actual embedding width rather than a hard-coded 512 so other
    # CLIP backbones (with different widths) also work.
    return features.reshape(lead_shape + [features.shape[-1]])

def GetFs(fs):
    """Collapse per-image CLIP features into one unit direction per entry.

    ``fs`` is indexed as ``fs[entry, image, k, feature]``. ``k = 1`` and
    ``k = 0`` are the two manipulation strengths: with ``alpha=[-5, 5]``
    they are the +5 sigma and -5 sigma edits.

    The steps are:
    1. L2-normalise every feature vector.
    2. Take the difference (k=1 minus k=0) per image and normalise it.
    3. Average the differences over the image axis.
    4. Renormalise the average.

    Returns an array of shape ``(entries, features)`` whose rows have unit
    norm.
    """
    unit = fs / np.linalg.norm(fs, axis=-1, keepdims=True)

    delta = unit[:, :, 1, :] - unit[:, :, 0, :]  # 5*sigma - (-5)* sigma
    delta = delta / np.linalg.norm(delta, axis=-1, keepdims=True)

    direction = delta.mean(axis=1)
    return direction / np.linalg.norm(direction, axis=-1, keepdims=True)

#%%
if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser(
        description='Compute CLIP image-feature directions (fs3) for the '
                    'channels of a Manipulator and save them to '
                    './npy/<dataset_name>/fs3.npy.')

    parser.add_argument('--dataset_name',type=str,default='cat',
                    help='name of dataset, for example, ffhq')
    args = parser.parse_args()
    dataset_name=args.dataset_name

    #%%
    # `device`, `model` and `preprocess` are intentionally module-level names.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    #%%
    M=Manipulator(dataset_name=dataset_name)
    np.set_printoptions(suppress=True)
    print(M.dataset_name)
    #%%
    # Take the same block of num_images latents (block img_sindex) from every layer.
    img_sindex=0
    num_images=100
    start=img_sindex*num_images
    dlatents_o=[layer[start:(start+num_images)] for layer in M.dlatents]
    #%%

    all_f=[]
    M.alpha=[-5,5] #ffhq 5
    M.step=2
    M.num_images=num_images
    mindexs=np.array(M.mindexs)
    mindexs2=mindexs[mindexs<=16] #below or equal to 128 resolution
    for lindex in mindexs2: #ignore ToRGB layers
        print(lindex)
        num_c=dlatents_o[lindex].shape[1]
        for cindex in range(num_c):
            # The list copy is shallow, so also copy the one layer being edited.
            # Without this, resetting channel cindex to its mean writes into
            # dlatents_o, and every later channel's run would start from
            # latents with all earlier channels already reset.
            M.dlatents=copy.copy(dlatents_o)
            M.dlatents[lindex]=dlatents_o[lindex].copy()
            M.dlatents[lindex][:,cindex]=M.code_mean[lindex][cindex]

            M.manipulate_layers=[lindex]
            _,out=M.EditOneC(cindex)
            all_f.append(GetImgF(out,model,preprocess))

    all_f=np.array(all_f)

    fs3=GetFs(all_f)

    #%%
    file_path='./npy/'+M.dataset_name+'/'
    # Ensure the output directory exists so the long computation above is
    # not lost to a missing-directory error at save time.
    os.makedirs(file_path, exist_ok=True)
    np.save(file_path+'fs3',fs3)