ludusc commited on
Commit
c3c13cc
·
1 Parent(s): 13f53a4

dtypes weird stuff

Browse files
Files changed (1) hide show
  1. backend/disentangle_concepts.py +8 -6
backend/disentangle_concepts.py CHANGED
@@ -84,14 +84,16 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
84
  for _, lambda_ in enumerate(lambdas):
85
  z_0 = z + lambda_ * decision_boundary
86
  if latent_space == 'Z':
87
- W_0 = G.mapping(z_0, label, truncation_psi=1)
88
- W = G.mapping(z, label, truncation_psi=1)
 
89
  else:
90
- W_0 = z_0.expand((14, -1)).unsqueeze(0)
91
- W = z.expand((14, -1)).unsqueeze(0)
 
92
 
93
  if layers:
94
- W_f = torch.empty_like(W).copy_(W)
95
  W_f[:, layers, :] = W_0[:, layers, :]
96
  img = G.synthesis(W_f, noise_mode='const')
97
  else:
@@ -125,7 +127,7 @@ def generate_original_image(z, model, latent_space='Z'):
125
  label = torch.zeros([1, G.c_dim], device=device)
126
  if latent_space == 'Z':
127
  z = torch.from_numpy(z.copy()).to(device)
128
- img = G(z, label, truncation_psi=0.7, noise_mode='const')
129
  else:
130
  W = torch.from_numpy(np.repeat(z, 14, axis=0).reshape(1, 14, z.shape[1]).copy()).to(device)
131
  print(W.shape)
 
84
  for _, lambda_ in enumerate(lambdas):
85
  z_0 = z + lambda_ * decision_boundary
86
  if latent_space == 'Z':
87
+ W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
88
+ W = G.mapping(z, label, truncation_psi=1).to(torch.float32)
89
+ print(W.dtype)
90
  else:
91
+ W_0 = z_0.expand((14, -1)).unsqueeze(0).to(torch.float32)
92
+ W = z.expand((14, -1)).unsqueeze(0).to(torch.float32)
93
+ print(W.dtype)
94
 
95
  if layers:
96
+ W_f = torch.empty_like(W).copy_(W).to(torch.float32)
97
  W_f[:, layers, :] = W_0[:, layers, :]
98
  img = G.synthesis(W_f, noise_mode='const')
99
  else:
 
127
  label = torch.zeros([1, G.c_dim], device=device)
128
  if latent_space == 'Z':
129
  z = torch.from_numpy(z.copy()).to(device)
130
+ img = G(z, label, truncation_psi=1, noise_mode='const')
131
  else:
132
  W = torch.from_numpy(np.repeat(z, 14, axis=0).reshape(1, 14, z.shape[1]).copy()).to(device)
133
  print(W.shape)