#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 4 17:36:31 2021

@author: wuzongze
"""

import os
# Uncomment to pin this process to a specific GPU:
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import sys

import tensorflow as tf
import numpy as np
import torch
import clip
from PIL import Image
import pickle
import copy
import matplotlib.pyplot as plt

# The original script relied on a `device` defined elsewhere; defining it
# here keeps this module self-contained.
device = "cuda" if torch.cuda.is_available() else "cpu"


def GetAlign(out, dt, model, preprocess):
    """Measure how well a manipulation sequence moves along the CLIP text
    direction `dt`.

    `out` holds uint8 images of shape [num_samples, num_steps, H, W, 3].
    Returns the mean projection of consecutive CLIP image-feature
    differences onto `dt`, and the fraction of steps with a positive
    projection.
    """
    imgs = out
    imgs1 = imgs.reshape([-1] + list(imgs.shape[2:]))

    # Preprocess every frame and encode the whole batch with CLIP.
    tmp = []
    for i in range(len(imgs1)):
        img = Image.fromarray(imgs1[i])
        image = preprocess(img).unsqueeze(0).to(device)
        tmp.append(image)
    image = torch.cat(tmp)

    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    image_features1 = image_features.cpu().numpy()
    image_features1 = image_features1.reshape(list(imgs.shape[:2]) + [512])

    # Feature differences between consecutive steps, L2-normalized.
    fd = image_features1[:, 1:, :] - image_features1[:, :-1, :]
    fd1 = fd.reshape([-1, 512])
    fd2 = fd1 / np.linalg.norm(fd1, axis=1)[:, None]

    tmp = np.dot(fd2, dt)
    m = tmp.mean()
    acc = np.sum(tmp > 0) / len(tmp)
    print(m, acc)
    return m, acc


def SplitS(ds_p, M, if_std):
    """Split a flat style-space direction `ds_p` into per-layer chunks
    matching the layer layout of the manipulator `M`, optionally rescaling
    each chunk by the per-channel standard deviation. tRGB layers are not
    manipulated and receive zero vectors.
    """
    all_ds = []
    start = 0
    for i in M.mindexs:
        tmp = M.dlatents[i].shape[1]
        end = start + tmp
        tmp = ds_p[start:end]
        all_ds.append(tmp)
        start = end

    all_ds2 = []
    tmp_index = 0
    for i in range(len(M.s_names)):
        if ('RGB' not in M.s_names[i]) and (len(all_ds[tmp_index]) != 0):
            if if_std:
                tmp = all_ds[tmp_index] * M.code_std[i]
            else:
                tmp = all_ds[tmp_index]
            all_ds2.append(tmp)
            tmp_index += 1
        else:
            tmp = np.zeros(len(M.dlatents[i][0]))
            all_ds2.append(tmp)
    return all_ds2


# Prompt-ensembling templates from CLIP's zero-shot ImageNet evaluation.
imagenet_templates = [
    'a bad photo of a {}.',
#    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]


def zeroshot_classifier(classnames, templates, model):
    """Build prompt-ensembled CLIP text embeddings, one column per class
    name (the standard CLIP zero-shot classifier construction).
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)  # average over templates
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights


def GetDt(classnames, model):
    """Return the unit-norm CLIP text direction between the two prompts in
    `classnames` (first minus second; in StyleCLIP this pair is typically
    [target, neutral], so `dt` points from neutral toward target).
    """
    text_features = zeroshot_classifier(classnames, imagenet_templates, model).t()

    dt = text_features[0] - text_features[1]
    dt = dt.cpu().numpy()
    print(np.linalg.norm(dt))
    dt = dt / np.linalg.norm(dt)
    return dt


def GetBoundary(fs3, dt, M, threshold):
    """Compute a global style-space manipulation direction.

    Projects the per-channel relevance matrix `fs3` onto the text
    direction `dt`, zeroes channels whose relevance magnitude falls below
    `threshold`, rescales so the strongest channel has magnitude 1, and
    splits the result into per-layer directions with SplitS.
    """
    tmp = np.dot(fs3, dt)

    ds_imp = copy.copy(tmp)
    select = np.abs(tmp) < threshold
    num_c = np.sum(~select)  # number of channels kept

    ds_imp[select] = 0
    tmp = np.abs(ds_imp).max()
    ds_imp /= tmp

    boundary_tmp2 = SplitS(ds_imp, M, if_std=True)
    print('num of channels being manipulated:', num_c)
    return boundary_tmp2, num_c
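
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the
# surrounding StyleCLIP global-directions project: a `Manipulator` class in
# manipulate.py and a precomputed channel-relevance matrix such as
# npy/ffhq/fs3.npy. Those names, the prompt pair, and threshold=0.13 are
# illustrative assumptions, not definitions made in this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from manipulate import Manipulator  # assumed project module

    # Load CLIP; jit=False returns the regular (non-TorchScript) model.
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

    # Style-space manipulator and per-channel CLIP relevance (assumed paths).
    M = Manipulator(dataset_name='ffhq')
    fs3 = np.load('./npy/ffhq/fs3.npy')

    # Text direction from a [target, neutral] prompt pair, then the
    # per-layer manipulation direction it induces in style space.
    classnames = ['face with smile', 'face']
    dt = GetDt(classnames, model)
    boundary, num_c = GetBoundary(fs3, dt, M, threshold=0.13)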