# latent-space-theories / backend / disentangle_concepts.py
import numpy as np
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
import torch
import PIL.Image


def get_separation_space(type_bin, annotations, df):
    """Fit a linear SVM that separates the 200 lowest-scoring from the 200
    highest-scoring samples for the given concept column and return the
    separating hyperplane (normal vector) plus a count of important features."""
    abstracts = np.array([float(ann) for ann in df[type_bin]])
    # 200 samples with the lowest concept scores vs. 200 with the highest.
    abstract_idxs = list(np.argsort(abstracts))[:200]
    repr_idxs = list(np.argsort(abstracts))[-200:]
    X = np.array([annotations['z_vectors'][i] for i in abstract_idxs + repr_idxs])
    X = X.reshape((400, 512))
    y = np.array([1] * 200 + [0] * 200)
    x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
    svc = SVC(gamma='auto', kernel='linear')
    svc.fit(x_train, y_train)
    # Report validation accuracy.
    print(svc.score(x_val, y_val))
    # Count coefficients with absolute weight above 0.1 as "important" features.
    imp_features = (np.abs(svc.coef_) > 0.1).sum()
    return svc.coef_, imp_features
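
# A minimal usage sketch (illustrative, not part of the original module). It
# assumes `annotations` is a dict whose 'z_vectors' entry holds 512-dim latent
# codes and `df` is a pandas DataFrame with a per-sample score column, here the
# hypothetical column name 'abstract':
#
#   separation_vector, n_important = get_separation_space('abstract', annotations, df)
#   print(separation_vector.shape, n_important)   # (1, 512) and an integer count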


def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
    """Shift the latent code along the decision boundary and decode an image at
    each step, returning the images together with the step sizes used."""
    device = torch.device('cpu')
    G = model.to(device)  # type: ignore
    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    z = torch.from_numpy(z.copy()).to(device)
    decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)

    lambdas = np.linspace(min_epsilon, max_epsilon, count)
    images = []
    # Generate images.
    for lambda_ in lambdas:
        z_0 = z + lambda_ * decision_boundary
        # Construct an inverse rotation/translation matrix and pass to the generator. The
        # generator expects this matrix as an inverse to avoid potentially failing numerical
        # operations in the network.
        # if hasattr(G.synthesis, 'input'):
        #     m = make_transform(translate, rotate)
        #     m = np.linalg.inv(m)
        #     G.synthesis.input.transform.copy_(torch.from_numpy(m))
        img = G(z_0, label, truncation_psi=0.7, noise_mode='random')
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
    return images, lambdas
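
# A minimal sketch of stepping along the boundary (hypothetical names; `G` is
# assumed to be a StyleGAN-style generator exposing `c_dim` and accepting
# (z, label, truncation_psi, noise_mode), and `z` a (1, 512) NumPy latent code):
#
#   imgs, lambdas = regenerate_images(G, z, separation_vector,
#                                     min_epsilon=-3, max_epsilon=3, count=5)
#   for im, lam in zip(imgs, lambdas):
#       im.save(f'shift_{lam:+.2f}.png')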


def generate_original_image(z, model):
    """Decode the unshifted latent code `z` into an image for comparison."""
    device = torch.device('cpu')
    G = model.to(device)  # type: ignore
    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    z = torch.from_numpy(z.copy()).to(device)
    img = G(z, label, truncation_psi=0.7, noise_mode='random')
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
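
# For a side-by-side comparison, the unedited image can be decoded from the same
# latent code (hypothetical names, reusing `G` and `z` from the sketches above):
#
#   baseline = generate_original_image(z, G)
#   baseline.save('original.png')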