dattarij commited on
Commit
5e94425
β€’
1 Parent(s): a411d5c

testing new latent space visualization code

Browse files
Files changed (1) hide show
  1. ContraCLIP/traverse_latent_space.py +140 -3
ContraCLIP/traverse_latent_space.py CHANGED
@@ -10,6 +10,10 @@ from torchvision.transforms import ToPILImage
10
  from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS
11
  from models.load_generator import load_generator
12
 
 
 
 
 
13
 
14
  class DataParallelPassthrough(nn.DataParallel):
15
  def __getattr__(self, name):
@@ -97,6 +101,63 @@ def create_gif(image_list, gif_height=256):
97
 
98
  return transformed_images_gif_frames
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  def main():
102
  """ContraCLIP -- Latent space traversal script.
@@ -210,6 +271,8 @@ def main():
210
  # -- Get prompt corpus list
211
  with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f:
212
  semantic_dipoles = json.load(f)
 
 
213
 
214
  # Check given pool directory
215
  pool = osp.join('experiments', 'latent_codes', gan, args.pool)
@@ -321,6 +384,9 @@ def main():
321
  print(" \\__Shift steps : {}".format(2 * args.shift_steps))
322
  print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3)))
323
 
 
 
 
324
  # Iterate over given latent codes
325
  for i in range(num_of_latent_codes):
326
  # Get latent code
@@ -333,6 +399,9 @@ def main():
333
  num_of_latent_codes),
334
  num_of_latent_codes, i)
335
 
 
 
 
336
  # Create directory for current latent code
337
  latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash))
338
  os.makedirs(latent_code_dir, exist_ok=True)
@@ -386,7 +455,7 @@ def main():
386
  latent_code = latent_code[:, 0, :]
387
 
388
  cnt = 0
389
- for _ in range(args.shift_steps):
390
  cnt += 1
391
 
392
  # Calculate shift vector based on current z
@@ -410,6 +479,10 @@ def main():
410
  latent_code = latent_code + shift
411
  current_path_latent_code = latent_code
412
 
 
 
 
 
413
  # Store latent codes and shifts
414
  if cnt == args.shift_leap:
415
  if ('stylegan' in gan) and (stylegan_space == 'W+'):
@@ -421,6 +494,8 @@ def main():
421
  current_path_latent_codes.append(current_path_latent_code)
422
  cnt = 0
423
  positive_endpoint = latent_code.clone().reshape(1, -1)
 
 
424
  # ========================
425
 
426
  # == Negative direction ==
@@ -430,7 +505,7 @@ def main():
430
  if stylegan_space == 'W':
431
  latent_code = latent_code[:, 0, :]
432
  cnt = 0
433
- for _ in range(args.shift_steps):
434
  cnt += 1
435
  # Calculate shift vector based on current z
436
  support_sets_mask = torch.zeros(1, LSS.num_support_sets)
@@ -453,6 +528,10 @@ def main():
453
  latent_code = latent_code + shift
454
  current_path_latent_code = latent_code
455
 
 
 
 
 
456
  # Store latent codes and shifts
457
  if cnt == args.shift_leap:
458
  if ('stylegan' in gan) and (stylegan_space == 'W+'):
@@ -464,6 +543,8 @@ def main():
464
  current_path_latent_codes = [current_path_latent_code] + current_path_latent_codes
465
  cnt = 0
466
  negative_endpoint = latent_code.clone().reshape(1, -1)
 
 
467
  # ========================
468
 
469
  # Calculate latent path phi coefficient (end-to-end distance / latent path length)
@@ -531,13 +612,69 @@ def main():
531
 
532
  # Save all latent paths and shifts for the current latent code (sample) in a tensor of size:
533
  # paths_latent_codes : torch.Size([num_gen_paths, 2 * args.shift_steps + 1, G.dim_z])
534
- torch.save(torch.cat(paths_latent_codes), osp.join(latent_code_dir, 'paths_latent_codes.pt'))
 
 
535
 
536
  if args.verbose:
537
  update_stdout(1)
538
  print()
539
  print()
540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  # Create summarizing MD files
542
  if args.gif or args.strip:
543
  # For each interpretable path (warping function), collect the generated image sequences for each original latent
 
10
  from lib import SupportSets, GENFORCE_MODELS, update_progress, update_stdout, STYLEGAN_LAYERS
11
  from models.load_generator import load_generator
12
 
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ from mpl_toolkits.mplot3d import Axes3D
16
+ from sklearn.manifold import TSNE
17
 
18
  class DataParallelPassthrough(nn.DataParallel):
19
  def __getattr__(self, name):
 
101
 
102
  return transformed_images_gif_frames
103
 
104
+ def visualize_latent_space(tsne_latent_codes, semantic_dipoles, output_dir, save_filename="latent_space_tsne.png", shift_steps=16):
105
+ """
106
+ Visualize the t-SNE reduced latent space with minimal annotations.
107
+
108
+ Args:
109
+ tsne_latent_codes (np.ndarray): The 3D latent codes after t-SNE transformation.
110
+ semantic_dipoles (list): List of semantic directions (labels) for paths.
111
+ shift_steps (int): Number of positive/negative steps along each path.
112
+ output_dir (str): Directory to save the generated plot.
113
+ save_filename (str): Name of the file to save the plot.
114
+ """
115
+ fig = plt.figure(figsize=(16, 12)) # Larger figure for clarity
116
+ ax = fig.add_subplot(111, projection='3d')
117
+
118
+ num_paths = len(semantic_dipoles) # Each dipole represents one path
119
+ cmap = plt.cm.get_cmap('tab10', num_paths)
120
+
121
+ for i in range(num_paths):
122
+ # Indices for the path in tsne_latent_codes
123
+ start_idx = i * (2 * shift_steps + 1)
124
+ pos_idx = start_idx + shift_steps # Positive endpoint
125
+ neg_idx = start_idx + 2 * shift_steps # Negative endpoint
126
+
127
+ # Extract path points
128
+ path_indices = list(range(start_idx, neg_idx + 1))
129
+ path_coords = tsne_latent_codes[path_indices]
130
+
131
+ # Plot the entire path (all intermediate points in a single color)
132
+ ax.plot(
133
+ path_coords[:, 0], path_coords[:, 1], path_coords[:, 2],
134
+ color=cmap(i),
135
+ linewidth=2
136
+ )
137
+
138
+ # Extract positive and negative endpoint coordinates
139
+ pos_coords = tsne_latent_codes[pos_idx]
140
+ neg_coords = tsne_latent_codes[neg_idx]
141
+
142
+ # Plot positive and negative endpoints
143
+ ax.scatter(*pos_coords, color=cmap(i), s=100, label=f"{semantic_dipoles[i][0]} β†’ {semantic_dipoles[i][1]}")
144
+ ax.scatter(*neg_coords, color=cmap(i), s=100)
145
+
146
+ # Add legend
147
+ ax.legend(loc='best', fontsize=10)
148
+
149
+ # Set titles and labels
150
+ ax.set_title("t-SNE Latent Space Visualization")
151
+ ax.set_xlabel("t-SNE Dimension 1")
152
+ ax.set_ylabel("t-SNE Dimension 2")
153
+ ax.set_zlabel("t-SNE Dimension 3")
154
+
155
+ # Save the plot
156
+ os.makedirs(output_dir, exist_ok=True)
157
+ save_path = osp.join(output_dir, save_filename)
158
+ plt.savefig(save_path, dpi=300, bbox_inches="tight")
159
+ print(f"Visualization saved to {save_path}")
160
+
161
 
162
  def main():
163
  """ContraCLIP -- Latent space traversal script.
 
271
  # -- Get prompt corpus list
272
  with open(osp.join(models_dir, 'semantic_dipoles.json'), 'r') as f:
273
  semantic_dipoles = json.load(f)
274
+
275
+ # semantic_directions = [f"{dipole[0]} β†’ {dipole[1]}" for dipole in semantic_dipoles]
276
 
277
  # Check given pool directory
278
  pool = osp.join('experiments', 'latent_codes', gan, args.pool)
 
384
  print(" \\__Shift steps : {}".format(2 * args.shift_steps))
385
  print(" \\__Traversal length : {}".format(round(2 * args.shift_steps * args.eps, 3)))
386
 
387
+ # Store latent codes for T-SNE visualization (for all paths across each latent code)
388
+ all_paths_latent_codes = []
389
+
390
  # Iterate over given latent codes
391
  for i in range(num_of_latent_codes):
392
  # Get latent code
 
399
  num_of_latent_codes),
400
  num_of_latent_codes, i)
401
 
402
+ # Append the starting latent code to tsne_latent_codes
403
+ # tsne_latent_codes.append(x_.clone().cpu().numpy().flatten())
404
+
405
  # Create directory for current latent code
406
  latent_code_dir = osp.join(out_dir, '{}'.format(latent_code_hash))
407
  os.makedirs(latent_code_dir, exist_ok=True)
 
455
  latent_code = latent_code[:, 0, :]
456
 
457
  cnt = 0
458
+ for k in range(args.shift_steps):
459
  cnt += 1
460
 
461
  # Calculate shift vector based on current z
 
479
  latent_code = latent_code + shift
480
  current_path_latent_code = latent_code
481
 
482
+ # Append intermediate latent code
483
+ # if k != args.shift_steps - 1:
484
+ # tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
485
+
486
  # Store latent codes and shifts
487
  if cnt == args.shift_leap:
488
  if ('stylegan' in gan) and (stylegan_space == 'W+'):
 
494
  current_path_latent_codes.append(current_path_latent_code)
495
  cnt = 0
496
  positive_endpoint = latent_code.clone().reshape(1, -1)
497
+
498
+ # tsne_latent_codes.append(positive_endpoint.clone().cpu().numpy().flatten())
499
  # ========================
500
 
501
  # == Negative direction ==
 
505
  if stylegan_space == 'W':
506
  latent_code = latent_code[:, 0, :]
507
  cnt = 0
508
+ for k in range(args.shift_steps):
509
  cnt += 1
510
  # Calculate shift vector based on current z
511
  support_sets_mask = torch.zeros(1, LSS.num_support_sets)
 
528
  latent_code = latent_code + shift
529
  current_path_latent_code = latent_code
530
 
531
+ # Append intermediate latent code
532
+ # if k != args.shift_steps - 1:
533
+ # tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
534
+
535
  # Store latent codes and shifts
536
  if cnt == args.shift_leap:
537
  if ('stylegan' in gan) and (stylegan_space == 'W+'):
 
543
  current_path_latent_codes = [current_path_latent_code] + current_path_latent_codes
544
  cnt = 0
545
  negative_endpoint = latent_code.clone().reshape(1, -1)
546
+
547
+ # tsne_latent_codes.append(latent_code.clone().cpu().numpy().flatten())
548
  # ========================
549
 
550
  # Calculate latent path phi coefficient (end-to-end distance / latent path length)
 
612
 
613
  # Save all latent paths and shifts for the current latent code (sample) in a tensor of size:
614
  # paths_latent_codes : torch.Size([num_gen_paths, 2 * args.shift_steps + 1, G.dim_z])
615
+ paths_latent_codes_tensor = torch.cat(paths_latent_codes)
616
+ torch.save(paths_latent_codes_tensor, osp.join(latent_code_dir, 'paths_latent_codes.pt'))
617
+ all_paths_latent_codes.append(paths_latent_codes_tensor.cpu().numpy())
618
 
619
  if args.verbose:
620
  update_stdout(1)
621
  print()
622
  print()
623
 
624
+ # After processing all latent codes and paths
625
+ if args.verbose:
626
+ print("Performing t-SNE on latent codes for visualization...")
627
+
628
+ # # Consolidate all paths for T-SNE visualization (total_paths = num_of_latent_codes * num_gen_paths)
629
+ # all_paths_np = np.concatenate(all_paths_latent_codes, axis=0) # Shape: [total_paths, steps_per_path, latent_dim]
630
+ # all_paths_flattened = all_paths_np.reshape(-1, all_paths_np.shape[-1]) # Flatten paths into 2D array for T-SNE
631
+
632
+ # # Apply 3D T-SNE
633
+ # tsne_model = TSNE(n_components=3, perplexity=30, learning_rate=200, random_state=42)
634
+ # tsne_transformed = tsne_model.fit_transform(all_paths_flattened) # Shape: [total_points, 3]
635
+
636
+ # path_indices = [] # List to store indices for each path
637
+ # start_idx = 0 # Starting index for the current path in all_paths_np
638
+
639
+ # steps_per_path = 2 * args.shift_steps + 1 # Number of points in each path
640
+
641
+ # # Iterate over each latent code and its paths
642
+ # for i in range(num_of_latent_codes): # Loop through latent codes
643
+ # for dim in range(num_gen_paths): # Loop through directions (paths)
644
+ # # Generate the indices for this path
645
+ # indices = list(range(start_idx, start_idx + steps_per_path))
646
+ # path_indices.append(indices)
647
+
648
+ # # Update the starting index for the next path
649
+ # start_idx += steps_per_path
650
+
651
+
652
+ all_paths_latent_code_0 = all_paths_latent_codes[0]
653
+ num_paths, num_steps, _ = all_paths_latent_code_0.shape
654
+ tsne_latent_codes = all_paths_latent_code_0.reshape(-1, all_paths_latent_code_0.shape[-1])
655
+
656
+ # Apply 3D T-SNE
657
+ tsne_model = TSNE(n_components=3, perplexity=30, learning_rate=200, random_state=42)
658
+ tsne_transformed = tsne_model.fit_transform(tsne_latent_codes) # Shape: [total_points = num_paths * num_steps, 3]
659
+
660
+ # For this specific latent code, generate indices for each of its paths
661
+ path_indices = []
662
+ start_idx = 0
663
+ for _ in range(num_paths):
664
+ indices = list(range(start_idx, start_idx + num_steps))
665
+ path_indices.append(indices)
666
+ start_idx += num_steps
667
+
668
+
669
+ tsne_vis_dir = osp.join(out_dir, 'tsne_visualizations')
670
+ visualize_latent_space(
671
+ tsne_latent_codes=tsne_transformed, # T-SNE-reduced latent codes
672
+ semantic_dipoles=semantic_dipoles, # Semantic labels for paths
673
+ paths=path_indices, # Indices of paths (for a single latent code)
674
+ output_dir=tsne_vis_dir,
675
+ save_filename="latent_space_tsne.png"
676
+ )
677
+
678
  # Create summarizing MD files
679
  if args.gif or args.strip:
680
  # For each interpretable path (warping function), collect the generated image sequences for each original latent