|
|
|
|
|
""" |
|
Created on Thu Feb 4 17:36:31 2021 |
|
|
|
@author: wuzongze |
|
""" |
|
|
|
import os |
|
|
|
|
|
|
|
import sys |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
import numpy as np |
|
import torch |
|
import clip |
|
from PIL import Image |
|
import pickle |
|
import copy |
|
import matplotlib.pyplot as plt |
|
|
|
def GetAlign(out,dt,model,preprocess): |
|
imgs=out |
|
imgs1=imgs.reshape([-1]+list(imgs.shape[2:])) |
|
|
|
tmp=[] |
|
for i in range(len(imgs1)): |
|
|
|
img=Image.fromarray(imgs1[i]) |
|
image = preprocess(img).unsqueeze(0).to(device) |
|
tmp.append(image) |
|
|
|
image=torch.cat(tmp) |
|
|
|
with torch.no_grad(): |
|
image_features = model.encode_image(image) |
|
image_features = image_features / image_features.norm(dim=-1, keepdim=True) |
|
|
|
image_features1=image_features.cpu().numpy() |
|
|
|
image_features1=image_features1.reshape(list(imgs.shape[:2])+[512]) |
|
|
|
fd=image_features1[:,1:,:]-image_features1[:,:-1,:] |
|
|
|
fd1=fd.reshape([-1,512]) |
|
fd2=fd1/np.linalg.norm(fd1,axis=1)[:,None] |
|
|
|
tmp=np.dot(fd2,dt) |
|
m=tmp.mean() |
|
acc=np.sum(tmp>0)/len(tmp) |
|
print(m,acc) |
|
return m,acc |
|
|
|
|
|
def SplitS(ds_p,M,if_std): |
|
all_ds=[] |
|
start=0 |
|
for i in M.mindexs: |
|
tmp=M.dlatents[i].shape[1] |
|
end=start+tmp |
|
tmp=ds_p[start:end] |
|
|
|
|
|
all_ds.append(tmp) |
|
start=end |
|
|
|
all_ds2=[] |
|
tmp_index=0 |
|
for i in range(len(M.s_names)): |
|
if (not 'RGB' in M.s_names[i]) and (not len(all_ds[tmp_index])==0): |
|
|
|
|
|
|
|
|
|
|
|
if if_std: |
|
tmp=all_ds[tmp_index]*M.code_std[i] |
|
else: |
|
tmp=all_ds[tmp_index] |
|
|
|
all_ds2.append(tmp) |
|
tmp_index+=1 |
|
else: |
|
tmp=np.zeros(len(M.dlatents[i][0])) |
|
all_ds2.append(tmp) |
|
return all_ds2 |
|
|
|
|
|
imagenet_templates = [ |
|
'a bad photo of a {}.', |
|
|
|
'a sculpture of a {}.', |
|
'a photo of the hard to see {}.', |
|
'a low resolution photo of the {}.', |
|
'a rendering of a {}.', |
|
'graffiti of a {}.', |
|
'a bad photo of the {}.', |
|
'a cropped photo of the {}.', |
|
'a tattoo of a {}.', |
|
'the embroidered {}.', |
|
'a photo of a hard to see {}.', |
|
'a bright photo of a {}.', |
|
'a photo of a clean {}.', |
|
'a photo of a dirty {}.', |
|
'a dark photo of the {}.', |
|
'a drawing of a {}.', |
|
'a photo of my {}.', |
|
'the plastic {}.', |
|
'a photo of the cool {}.', |
|
'a close-up photo of a {}.', |
|
'a black and white photo of the {}.', |
|
'a painting of the {}.', |
|
'a painting of a {}.', |
|
'a pixelated photo of the {}.', |
|
'a sculpture of the {}.', |
|
'a bright photo of the {}.', |
|
'a cropped photo of a {}.', |
|
'a plastic {}.', |
|
'a photo of the dirty {}.', |
|
'a jpeg corrupted photo of a {}.', |
|
'a blurry photo of the {}.', |
|
'a photo of the {}.', |
|
'a good photo of the {}.', |
|
'a rendering of the {}.', |
|
'a {} in a video game.', |
|
'a photo of one {}.', |
|
'a doodle of a {}.', |
|
'a close-up photo of the {}.', |
|
'a photo of a {}.', |
|
'the origami {}.', |
|
'the {} in a video game.', |
|
'a sketch of a {}.', |
|
'a doodle of the {}.', |
|
'a origami {}.', |
|
'a low resolution photo of a {}.', |
|
'the toy {}.', |
|
'a rendition of the {}.', |
|
'a photo of the clean {}.', |
|
'a photo of a large {}.', |
|
'a rendition of a {}.', |
|
'a photo of a nice {}.', |
|
'a photo of a weird {}.', |
|
'a blurry photo of a {}.', |
|
'a cartoon {}.', |
|
'art of a {}.', |
|
'a sketch of the {}.', |
|
'a embroidered {}.', |
|
'a pixelated photo of a {}.', |
|
'itap of the {}.', |
|
'a jpeg corrupted photo of the {}.', |
|
'a good photo of a {}.', |
|
'a plushie {}.', |
|
'a photo of the nice {}.', |
|
'a photo of the small {}.', |
|
'a photo of the weird {}.', |
|
'the cartoon {}.', |
|
'art of the {}.', |
|
'a drawing of the {}.', |
|
'a photo of the large {}.', |
|
'a black and white photo of a {}.', |
|
'the plushie {}.', |
|
'a dark photo of a {}.', |
|
'itap of a {}.', |
|
'graffiti of the {}.', |
|
'a toy {}.', |
|
'itap of my {}.', |
|
'a photo of a cool {}.', |
|
'a photo of a small {}.', |
|
'a tattoo of the {}.', |
|
] |
|
|
|
|
|
def zeroshot_classifier(classnames, templates,model): |
|
with torch.no_grad(): |
|
zeroshot_weights = [] |
|
for classname in classnames: |
|
texts = [template.format(classname) for template in templates] |
|
texts = clip.tokenize(texts).cuda() |
|
class_embeddings = model.encode_text(texts) |
|
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) |
|
class_embedding = class_embeddings.mean(dim=0) |
|
class_embedding /= class_embedding.norm() |
|
zeroshot_weights.append(class_embedding) |
|
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() |
|
return zeroshot_weights |
|
|
|
|
|
def GetDt(classnames,model): |
|
text_features=zeroshot_classifier(classnames, imagenet_templates,model).t() |
|
|
|
dt=text_features[0]-text_features[1] |
|
dt=dt.cpu().numpy() |
|
|
|
|
|
|
|
print(np.linalg.norm(dt)) |
|
dt=dt/np.linalg.norm(dt) |
|
return dt |
|
|
|
|
|
def GetBoundary(fs3,dt,M,threshold): |
|
tmp=np.dot(fs3,dt) |
|
|
|
ds_imp=copy.copy(tmp) |
|
select=np.abs(tmp)<threshold |
|
num_c=np.sum(~select) |
|
|
|
|
|
ds_imp[select]=0 |
|
tmp=np.abs(ds_imp).max() |
|
ds_imp/=tmp |
|
|
|
boundary_tmp2=SplitS(ds_imp,M,if_std=True) |
|
print('num of channels being manipulated:',num_c) |
|
return boundary_tmp2,num_c |
|
|
|
def GetFs(file_path): |
|
fs=np.load(file_path+'single_channel.npy') |
|
tmp=np.linalg.norm(fs,axis=-1) |
|
fs1=fs/tmp[:,:,:,None] |
|
fs2=fs1[:,:,1,:]-fs1[:,:,0,:] |
|
fs3=fs2/np.linalg.norm(fs2,axis=-1)[:,:,None] |
|
fs3=fs3.mean(axis=1) |
|
fs3=fs3/np.linalg.norm(fs3,axis=-1)[:,None] |
|
return fs3 |
|
|
|
|
|
if __name__ == "__main__": |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
model, preprocess = clip.load("ViT-B/32", device=device) |
|
|
|
sys.path.append('/cs/labs/danix/wuzongze/Gan_Manipulation/play') |
|
from example_try import Manipulator4 |
|
|
|
M=Manipulator4(dataset_name='ffhq',code_type='S') |
|
np.set_printoptions(suppress=True) |
|
|
|
|
|
|
|
|
|
file_path='/cs/labs/danix/wuzongze/Tansformer_Manipulation/CLIP/results/'+M.dataset_name+'/' |
|
fs3=GetFs(file_path) |
|
|
|
|
|
|
|
|
|
''' |
|
text_features=zeroshot_classifier2(classnames, imagenet_templates) #.t() |
|
|
|
tmp=np.linalg.norm(text_features,axis=2) |
|
text_features/=tmp[:,:,None] |
|
dt=text_features[0]-text_features[1] |
|
|
|
tmp=np.linalg.norm(dt,axis=1) |
|
dt/=tmp[:,None] |
|
dt=dt.mean(axis=0) |
|
''' |
|
|
|
|
|
''' |
|
all_tmp=[] |
|
tmp=torch.load('/cs/labs/danix/wuzongze/downloads/harris_latent.pt') |
|
tmp=tmp.cpu().detach().numpy() #[:,:14,:] |
|
all_tmp.append(tmp) |
|
|
|
tmp=torch.load('/cs/labs/danix/wuzongze/downloads/ariana_latent.pt') |
|
tmp=tmp.cpu().detach().numpy() #[:,:14,:] |
|
all_tmp.append(tmp) |
|
|
|
tmp=torch.load('/cs/labs/danix/wuzongze/downloads/federer.pt') |
|
tmp=tmp.cpu().detach().numpy() #[:,:14,:] |
|
all_tmp.append(tmp) |
|
|
|
all_tmp=np.array(all_tmp)[:,0] |
|
|
|
dlatent_tmp=M.W2S(all_tmp) |
|
''' |
|
''' |
|
tmp=torch.load('/cs/labs/danix/wuzongze/downloads/all_cars.pt') |
|
tmp=tmp.cpu().detach().numpy()[:300] |
|
dlatent_tmp=M.W2S(tmp) |
|
''' |
|
''' |
|
tmp=torch.load('/cs/labs/danix/wuzongze/downloads/faces.pt') |
|
tmp=tmp.cpu().detach().numpy()[:100] |
|
dlatent_tmp=M.W2S(tmp) |
|
''' |
|
|
|
|
|
M.img_index=0 |
|
M.num_images=30 |
|
dlatent_tmp=[tmp[M.img_index:(M.img_index+M.num_images)] for tmp in M.dlatents] |
|
|
|
|
|
classnames=['face','face with glasses'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dt=GetDt(classnames,model) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
boundary_tmp2=GetBoundary(fs3,dt,M,threshold=0.13) |
|
|
|
|
|
M.start_distance=-20 |
|
M.end_distance=20 |
|
M.step=7 |
|
|
|
codes=M.MSCode(dlatent_tmp,boundary_tmp2) |
|
out=M.GenerateImg(codes) |
|
M.Vis2(str('tmp'),'filter2',out) |
|
|
|
|
|
|
|
|
|
|
|
boundary_tmp3=copy.copy(boundary_tmp2) |
|
boundary_tmp4=copy.copy(boundary_tmp2) |
|
|
|
boundary_tmp2=copy.copy(boundary_tmp3) |
|
for i in range(len(boundary_tmp3)): |
|
select=boundary_tmp4[i]==0 |
|
boundary_tmp2[i][~select]=0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|