testing new latent space visualization code
Browse files
ContraCLIP/traverse_latent_space.py
CHANGED
@@ -10,6 +10,10 @@ from torchvision.transforms import ToPILImage
|
|
10 |
from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS
|
11 |
from models.load_generator import load_generator
|
12 |
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class DataParallelPassthrough(nn.DataParallel):
|
15 |
def __getattr__(self, name):
|
@@ -97,6 +101,63 @@ def create_gif(image_list, gif_height=256):
|
|
97 |
|
98 |
return transformed_images_gif_frames
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
def main():
|
102 |
"""ContraCLIP -- Latent space traversal script.
|
@@ -210,6 +271,8 @@ def main():
|
|
210 |
# -- Get prompt corpus list
|
211 |
with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f:
|
212 |
semantic_dipoles = json.load(f)
|
|
|
|
|
213 |
|
214 |
# Check given pool directory
|
215 |
pool = osp.join('experiments', 'latent_codes', gan, args.pool)
|
@@ -321,6 +384,9 @@ def main():
|
|
321 |
print(" \\__Shift steps : {}".format(2 * args.shift_steps))
|
322 |
print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3)))
|
323 |
|
|
|
|
|
|
|
324 |
# Iterate over given latent codes
|
325 |
for i in range(num_of_latent_codes):
|
326 |
# Get latent code
|
@@ -333,6 +399,9 @@ def main():
|
|
333 |
num_of_latent_codes),
|
334 |
num_of_latent_codes, i)
|
335 |
|
|
|
|
|
|
|
336 |
# Create directory for current latent code
|
337 |
latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash))
|
338 |
os.makedirs(latent_code_dir, exist_ok=True)
|
@@ -386,7 +455,7 @@ def main():
|
|
386 |
latent_code = latent_code[:, 0, :]
|
387 |
|
388 |
cnt = 0
|
389 |
-
for
|
390 |
cnt += 1
|
391 |
|
392 |
# Calculate shift vector based on current z
|
@@ -410,6 +479,10 @@ def main():
|
|
410 |
latent_code = latent_code + shift
|
411 |
current_path_latent_code = latent_code
|
412 |
|
|
|
|
|
|
|
|
|
413 |
# Store latent codes and shifts
|
414 |
if cnt == args.shift_leap:
|
415 |
if ('stylegan' in gan) and (stylegan_space == 'W+'):
|
@@ -421,6 +494,8 @@ def main():
|
|
421 |
current_path_latent_codes.append(current_path_latent_code)
|
422 |
cnt = 0
|
423 |
positive_endpoint = latent_code.clone().reshape(1, -1)
|
|
|
|
|
424 |
# ========================
|
425 |
|
426 |
# == Negative direction ==
|
@@ -430,7 +505,7 @@ def main():
|
|
430 |
if stylegan_space == 'W':
|
431 |
latent_code = latent_code[:, 0, :]
|
432 |
cnt = 0
|
433 |
-
for
|
434 |
cnt += 1
|
435 |
# Calculate shift vector based on current z
|
436 |
support_sets_mask = torch.zeros(1, LSS.num_support_sets)
|
@@ -453,6 +528,10 @@ def main():
|
|
453 |
latent_code = latent_code + shift
|
454 |
current_path_latent_code = latent_code
|
455 |
|
|
|
|
|
|
|
|
|
456 |
# Store latent codes and shifts
|
457 |
if cnt == args.shift_leap:
|
458 |
if ('stylegan' in gan) and (stylegan_space == 'W+'):
|
@@ -464,6 +543,8 @@ def main():
|
|
464 |
current_path_latent_codes = [current_path_latent_code] + current_path_latent_codes
|
465 |
cnt = 0
|
466 |
negative_endpoint = latent_code.clone().reshape(1, -1)
|
|
|
|
|
467 |
# ========================
|
468 |
|
469 |
# Calculate latent path phi coefficient (end-to-end distance / latent path length)
|
@@ -531,13 +612,69 @@ def main():
|
|
531 |
|
532 |
# Save all latent paths and shifts for the current latent code (sample) in a tensor of size:
|
533 |
# paths_latent_codes : torch.Size([num_gen_paths, 2 * args.shift_steps + 1, G.dim_z])
|
534 |
-
torch.
|
|
|
|
|
535 |
|
536 |
if args.verbose:
|
537 |
update_stdout(1)
|
538 |
print()
|
539 |
print()
|
540 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
541 |
# Create summarizing MD files
|
542 |
if args.gif or args.strip:
|
543 |
# For each interpretable path (warping function), collect the generated image sequences for each original latent
|
|
|
10 |
from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS
|
11 |
from models.load_generator import load_generator
|
12 |
|
13 |
+
import numpy as np
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
from mpl_toolkits.mplot3d import Axes3D
|
16 |
+
from sklearn.manifold import TSNE
|
17 |
|
18 |
class DataParallelPassthrough(nn.DataParallel):
|
19 |
def __getattr__(self, name):
|
|
|
101 |
|
102 |
return transformed_images_gif_frames
|
103 |
|
104 |
+
def visualize_latent_space(tsne_latent_codes, semantic_dipoles, output_dir, save_filename="latent_space_tsne.png", shift_steps=16):
|
105 |
+
"""
|
106 |
+
Visualize the t-SNE reduced latent space with minimal annotations.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
tsne_latent_codes (np.ndarray): The 3D latent codes after t-SNE transformation.
|
110 |
+
semantic_dipoles (list): List of semantic directions (labels) for paths.
|
111 |
+
shift_steps (int): Number of positive/negative steps along each path.
|
112 |
+
output_dir (str): Directory to save the generated plot.
|
113 |
+
save_filename (str): Name of the file to save the plot.
|
114 |
+
"""
|
115 |
+
fig = plt.figure(figsize=(16, 12)) # Larger figure for clarity
|
116 |
+
ax = fig.add_subplot(111, projection='3d')
|
117 |
+
|
118 |
+
num_paths = len(semantic_dipoles) # Each dipole represents one path
|
119 |
+
cmap = plt.cm.get_cmap('tab10', num_paths)
|
120 |
+
|
121 |
+
for i in range(num_paths):
|
122 |
+
# Indices for the path in tsne_latent_codes
|
123 |
+
start_idx = i * (2 * shift_steps + 1)
|
124 |
+
pos_idx = start_idx + shift_steps # Positive endpoint
|
125 |
+
neg_idx = start_idx + 2 * shift_steps # Negative endpoint
|
126 |
+
|
127 |
+
# Extract path points
|
128 |
+
path_indices = list(range(start_idx, neg_idx + 1))
|
129 |
+
path_coords = tsne_latent_codes[path_indices]
|
130 |
+
|
131 |
+
# Plot the entire path (all intermediate points in a single color)
|
132 |
+
ax.plot(
|
133 |
+
path_coords[:, 0], path_coords[:, 1], path_coords[:, 2],
|
134 |
+
color=cmap(i),
|
135 |
+
linewidth=2
|
136 |
+
)
|
137 |
+
|
138 |
+
# Extract positive and negative endpoint coordinates
|
139 |
+
pos_coords = tsne_latent_codes[pos_idx]
|
140 |
+
neg_coords = tsne_latent_codes[neg_idx]
|
141 |
+
|
142 |
+
# Plot positive and negative endpoints
|
143 |
+
ax.scatter(*pos_coords, color=cmap(i), s=100, label=f"{semantic_dipoles[i][0]} β {semantic_dipoles[i][1]}")
|
144 |
+
ax.scatter(*neg_coords, color=cmap(i), s=100)
|
145 |
+
|
146 |
+
# Add legend
|
147 |
+
ax.legend(loc='best', fontsize=10)
|
148 |
+
|
149 |
+
# Set titles and labels
|
150 |
+
ax.set_title("t-SNE Latent Space Visualization")
|
151 |
+
ax.set_xlabel("t-SNE Dimension 1")
|
152 |
+
ax.set_ylabel("t-SNE Dimension 2")
|
153 |
+
ax.set_zlabel("t-SNE Dimension 3")
|
154 |
+
|
155 |
+
# Save the plot
|
156 |
+
os.makedirs(output_dir, exist_ok=True)
|
157 |
+
save_path = osp.join(output_dir, save_filename)
|
158 |
+
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
159 |
+
print(f"Visualization saved to {save_path}")
|
160 |
+
|
161 |
|
162 |
def main():
|
163 |
"""ContraCLIP -- Latent space traversal script.
|
|
|
271 |
# -- Get prompt corpus list
|
272 |
with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f:
|
273 |
semantic_dipoles = json.load(f)
|
274 |
+
|
275 |
+
# semantic_directions = [f"{dipole[0]} β {dipole[1]}" for dipole in semantic_dipoles]
|
276 |
|
277 |
# Check given pool directory
|
278 |
pool = osp.join('experiments', 'latent_codes', gan, args.pool)
|
|
|
384 |
print(" \\__Shift steps : {}".format(2 * args.shift_steps))
|
385 |
print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3)))
|
386 |
|
387 |
+
# Store latent codes for T-SNE visualization (for all paths across each latent code)
|
388 |
+
all_paths_latent_codes = []
|
389 |
+
|
390 |
# Iterate over given latent codes
|
391 |
for i in range(num_of_latent_codes):
|
392 |
# Get latent code
|
|
|
399 |
num_of_latent_codes),
|
400 |
num_of_latent_codes, i)
|
401 |
|
402 |
+
# Append the starting latent code to tsne_latent_codes
|
403 |
+
# tsne_latent_codes.append(x_.clone().cpu().numpy().flatten())
|
404 |
+
|
405 |
# Create directory for current latent code
|
406 |
latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash))
|
407 |
os.makedirs(latent_code_dir, exist_ok=True)
|
|
|
455 |
latent_code = latent_code[:, 0, :]
|
456 |
|
457 |
cnt = 0
|
458 |
+
for k in range(args.shift_steps):
|
459 |
cnt += 1
|
460 |
|
461 |
# Calculate shift vector based on current z
|
|
|
479 |
latent_code = latent_code + shift
|
480 |
current_path_latent_code = latent_code
|
481 |
|
482 |
+
# Append intermediate latent code
|
483 |
+
# if k != args.shift_steps - 1:
|
484 |
+
# tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
|
485 |
+
|
486 |
# Store latent codes and shifts
|
487 |
if cnt == args.shift_leap:
|
488 |
if ('stylegan' in gan) and (stylegan_space == 'W+'):
|
|
|
494 |
current_path_latent_codes.append(current_path_latent_code)
|
495 |
cnt = 0
|
496 |
positive_endpoint = latent_code.clone().reshape(1, -1)
|
497 |
+
|
498 |
+
# tsne_latent_codes.append(positive_endpoint.clone().cpu().numpy().flatten())
|
499 |
# ========================
|
500 |
|
501 |
# == Negative direction ==
|
|
|
505 |
if stylegan_space == 'W':
|
506 |
latent_code = latent_code[:, 0, :]
|
507 |
cnt = 0
|
508 |
+
for k in range(args.shift_steps):
|
509 |
cnt += 1
|
510 |
# Calculate shift vector based on current z
|
511 |
support_sets_mask = torch.zeros(1, LSS.num_support_sets)
|
|
|
528 |
latent_code = latent_code + shift
|
529 |
current_path_latent_code = latent_code
|
530 |
|
531 |
+
# Append intermediate latent code
|
532 |
+
# if k != args.shift_steps - 1:
|
533 |
+
# tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
|
534 |
+
|
535 |
# Store latent codes and shifts
|
536 |
if cnt == args.shift_leap:
|
537 |
if ('stylegan' in gan) and (stylegan_space == 'W+'):
|
|
|
543 |
current_path_latent_codes = [current_path_latent_code] + current_path_latent_codes
|
544 |
cnt = 0
|
545 |
negative_endpoint = latent_code.clone().reshape(1, -1)
|
546 |
+
|
547 |
+
# tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
|
548 |
# ========================
|
549 |
|
550 |
# Calculate latent path phi coefficient (end-to-end distance / latent path length)
|
|
|
612 |
|
613 |
# Save all latent paths and shifts for the current latent code (sample) in a tensor of size:
|
614 |
# paths_latent_codes : torch.Size([num_gen_paths, 2 * args.shift_steps + 1, G.dim_z])
|
615 |
+
paths_latent_codes_tensor = torch.cat(paths_latent_codes)
|
616 |
+
torch.save(paths_latent_codes_tensor, osp.join(latent_code_dir, 'paths_latent_codes.pt'))
|
617 |
+
all_paths_latent_codes.append(paths_latent_codes_tensor.cpu().numpy())
|
618 |
|
619 |
if args.verbose:
|
620 |
update_stdout(1)
|
621 |
print()
|
622 |
print()
|
623 |
|
624 |
+
# After processing all latent codes and paths
|
625 |
+
if args.verbose:
|
626 |
+
print("Performing t-SNE on latent codes for visualization...")
|
627 |
+
|
628 |
+
# # Consolidate all paths for T-SNE visualization (total_paths = num_of_latent_codes * num_gen_paths)
|
629 |
+
# all_paths_np = np.concatenate(all_paths_latent_codes, axis=0) # Shape: [total_paths, steps_per_path, latent_dim]
|
630 |
+
# all_paths_flattened = all_paths_np.reshape(-1, all_paths_np.shape[-1]) # Flatten paths into 2D array for T-SNE
|
631 |
+
|
632 |
+
# # Apply 3D T-SNE
|
633 |
+
# tsne_model = TSNE(n_components=3, perplexity=30, learning_rate=200, random_state=42)
|
634 |
+
# tsne_transformed = tsne_model.fit_transform(all_paths_flattened) # Shape: [total_points, 3]
|
635 |
+
|
636 |
+
# path_indices = [] # List to store indices for each path
|
637 |
+
# start_idx = 0 # Starting index for the current path in all_paths_np
|
638 |
+
|
639 |
+
# steps_per_path = 2 * args.shift_steps + 1 # Number of points in each path
|
640 |
+
|
641 |
+
# # Iterate over each latent code and its paths
|
642 |
+
# for i in range(num_of_latent_codes): # Loop through latent codes
|
643 |
+
# for dim in range(num_gen_paths): # Loop through directions (paths)
|
644 |
+
# # Generate the indices for this path
|
645 |
+
# indices = list(range(start_idx, start_idx + steps_per_path))
|
646 |
+
# path_indices.append(indices)
|
647 |
+
|
648 |
+
# # Update the starting index for the next path
|
649 |
+
# start_idx += steps_per_path
|
650 |
+
|
651 |
+
|
652 |
+
all_paths_latent_code_0 = all_paths_latent_codes[0]
|
653 |
+
num_paths, num_steps, _ = all_paths_latent_code_0.shape
|
654 |
+
tsne_latent_codes = all_paths_latent_code_0.reshape(-1, all_paths_latent_code_0.shape[-1])
|
655 |
+
|
656 |
+
# Apply 3D T-SNE
|
657 |
+
tsne_model = TSNE(n_components=3, perplexity=30, learning_rate=200, random_state=42)
|
658 |
+
tsne_transformed = tsne_model.fit_transform(tsne_latent_codes) # Shape: [total_points = num_paths * num_steps, 3]
|
659 |
+
|
660 |
+
# For this specific latent code, generate indices for each of its paths
|
661 |
+
path_indices = []
|
662 |
+
start_idx = 0
|
663 |
+
for _ in range(num_paths):
|
664 |
+
indices = list(range(start_idx, start_idx + num_steps))
|
665 |
+
path_indices.append(indices)
|
666 |
+
start_idx += num_steps
|
667 |
+
|
668 |
+
|
669 |
+
tsne_vis_dir = osp.join(out_dir, 'tsne_visualizations')
|
670 |
+
visualize_latent_space(
|
671 |
+
tsne_latent_codes=tsne_transformed, # T-SNE-reduced latent codes
|
672 |
+
semantic_dipoles=semantic_dipoles, # Semantic labels for paths
|
673 |
+
paths=path_indices, # Indices of paths (for a single latent code)
|
674 |
+
output_dir=tsne_vis_dir,
|
675 |
+
save_filename="latent_space_tsne.png"
|
676 |
+
)
|
677 |
+
|
678 |
# Create summarizing MD files
|
679 |
if args.gif or args.strip:
|
680 |
# For each interpretable path (warping function), collect the generated image sequences for each original latent
|