AbstractPhil committed on
Commit
70328ac
·
verified ·
1 Parent(s): b207f78

Update svd_triton_gram_newton.py

Browse files
Files changed (1) hide show
  1. svd_triton_gram_newton.py +158 -95
svd_triton_gram_newton.py CHANGED
@@ -492,25 +492,23 @@ def projected_svd_quality(A, target_rank=24):
492
 
493
 
494
  def procrustes_alignment_quality(N=48, k=24, n_samples=5000):
495
- """The real test: does rank-k Procrustes produce the same alignment as rank-N?
496
-
497
- Simulates the actual use case:
498
- 1. Generate two embedding spaces (source, target) with shared structure
499
- 2. Align with full N-d Procrustes
500
- 3. Align with projected k-d Procrustes
501
- 4. Compare: rotation agreement, aligned embedding cosine, downstream task impact
502
-
503
- Returns dict of comparison metrics.
504
  """
505
  device = 'cuda'
506
 
507
  # Create two embedding spaces with shared low-rank structure + noise
508
- # This simulates two expert encoders that agree on major directions
509
- shared_rank = min(N // 2, 32) # true shared structure
510
  shared_basis = torch.randn(shared_rank, N, device=device)
511
- shared_basis = torch.linalg.qr(shared_basis.T).Q.T # orthonormal rows
512
 
513
- # Source and target share the basis but with different coefficients + noise
514
  coeffs_src = torch.randn(n_samples, shared_rank, device=device)
515
  coeffs_tgt = torch.randn(n_samples, shared_rank, device=device) * 0.8 + coeffs_src * 0.5
516
  noise_scale = 0.3
@@ -518,119 +516,184 @@ def procrustes_alignment_quality(N=48, k=24, n_samples=5000):
518
  source = coeffs_src @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
519
  target = coeffs_tgt @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
520
 
521
- # Center
522
  source = source - source.mean(0, keepdim=True)
523
  target = target - target.mean(0, keepdim=True)
524
 
525
- # ═══ Full N-d Procrustes ═══
526
- # Cross-covariance β†’ SVD β†’ rotation
527
- C_full = source.T @ target # (N, N)
528
- U_f, S_f, Vh_f = torch.linalg.svd(C_full)
529
- R_full = U_f @ Vh_f # (N, N) optimal rotation
530
-
531
  aligned_full = source @ R_full
532
- cos_full = F.cosine_similarity(aligned_full, target, dim=-1) # (n_samples,)
533
 
534
  # ═══ Projected k-d Procrustes ═══
535
- # Project both spaces to k-d, align there, lift back
536
- # Random projection
537
  P = torch.randn(N, k, device=device) / math.sqrt(k)
538
-
539
- src_proj = source @ P # (n_samples, k)
540
- tgt_proj = target @ P # (n_samples, k)
541
-
542
- C_proj = src_proj.T @ tgt_proj # (k, k)
543
- U_p, S_p, Vh_p = torch.linalg.svd(C_proj)
544
- R_proj_k = U_p @ Vh_p # (k, k) rotation in projected space
545
-
546
- # Lift rotation back to N-d: R_N = P @ R_k @ P^T (pseudoinverse)
547
- # More precisely: align in projected space, then evaluate in full space
548
- aligned_proj_k = src_proj @ R_proj_k # aligned in k-d
549
- # Lift back: find best N-d rotation that maps source to target
550
- # using only the k-d alignment as guidance
551
- # R_lifted = P @ R_k @ pinv(P)
552
- P_pinv = torch.linalg.pinv(P) # (k, N)
553
- R_lifted = P @ R_proj_k @ P_pinv # (N, N)
554
- aligned_lifted = source @ R_lifted
555
- cos_lifted = F.cosine_similarity(aligned_lifted, target, dim=-1)
556
-
557
- # ═══ Also test: Procrustes in k-d only (don't lift, compare in k-d) ═══
558
- cos_proj_space = F.cosine_similarity(aligned_proj_k, tgt_proj, dim=-1)
559
- # Reference: full Procrustes projected to k-d
560
- aligned_full_proj = aligned_full @ P
561
- cos_full_proj = F.cosine_similarity(aligned_full_proj, tgt_proj, dim=-1)
562
-
563
- # ═══ Rotation agreement: how similar are R_full and R_lifted? ═══
564
- # Frobenius norm of difference
565
- rot_frob = (R_full - R_lifted).norm().item() / (R_full.norm().item() + 1e-8)
566
- # Trace agreement: tr(R_full^T @ R_lifted) / N β€” 1.0 if identical
567
- rot_trace = (R_full.T @ R_lifted).trace().item() / N
568
-
569
- # ═══ Downstream proxy: classification agreement ═══
570
- # If we classify by nearest-neighbor in aligned space, do both agree?
571
- # Use first 100 as "anchors", rest as queries
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
  n_anchor = min(100, n_samples // 2)
573
- anchors_full = aligned_full[:n_anchor]
574
- anchors_lift = aligned_lifted[:n_anchor]
575
- queries_full = aligned_full[n_anchor:]
576
- queries_lift = aligned_lifted[n_anchor:]
577
 
578
- nn_full = (queries_full @ anchors_full.T).argmax(-1)
579
- nn_lift = (queries_lift @ anchors_lift.T).argmax(-1)
580
- nn_agreement = (nn_full == nn_lift).float().mean().item()
 
 
 
 
 
 
 
 
581
 
582
  return {
583
  'N': N, 'k': k,
584
- 'cos_full_mean': cos_full.mean().item(), # alignment quality: full Procrustes
585
- 'cos_lifted_mean': cos_lifted.mean().item(), # alignment quality: projected + lifted
586
- 'cos_proj_space': cos_proj_space.mean().item(), # alignment in k-d only
587
- 'cos_full_proj': cos_full_proj.mean().item(), # full Procrustes seen from k-d
588
- 'rot_frob_rel': rot_frob, # rotation matrix difference (relative)
589
- 'rot_trace_norm': rot_trace, # rotation trace agreement (1.0 = perfect)
590
- 'nn_agreement': nn_agreement, # nearest-neighbor classification agreement
 
 
591
  }
592
 
593
 
594
  def profile_procrustes_quality():
595
- """Compare Procrustes alignment quality: full N-d vs projected k-d."""
596
- print(f"\n{'='*100}")
597
- print(f" PROCRUSTES ALIGNMENT QUALITY: full N-d vs projected k-d")
598
- print(f" Does rank-k alignment produce the same rotation as rank-N?")
599
- print(f"{'='*100}")
 
600
 
601
  configs = [
602
- (32, [8, 12, 16, 24]),
603
- (48, [8, 12, 16, 24, 32]),
604
- (64, [8, 12, 16, 24, 32]),
605
- (96, [8, 16, 24, 32, 48]),
606
- (128, [8, 16, 24, 32, 48, 64]),
607
  ]
608
 
609
  all_results = []
610
 
611
  for N, ranks in configs:
612
  print(f"\n N={N}:")
613
- print(f" {'k':>5} {'cos_full':>9} {'cos_lifted':>11} {'cos_k-d':>9}"
614
- f" {'rot_trace':>10} {'rot_frob':>10} {'NN_agree':>9}")
615
- print(f" {'─'*76}")
 
616
 
617
  for k in ranks:
618
  if k >= N:
619
  continue
620
  q = procrustes_alignment_quality(N=N, k=k)
621
- print(f" {k:>5} {q['cos_full_mean']:>9.4f} {q['cos_lifted_mean']:>11.4f}"
622
- f" {q['cos_proj_space']:>9.4f} {q['rot_trace_norm']:>10.4f}"
623
- f" {q['rot_frob_rel']:>10.4f} {q['nn_agreement']:>9.4f}")
 
 
 
 
 
 
624
  all_results.append(q)
625
 
626
- # Summary
627
- print(f"\n {'─'*76}")
628
- print(f" KEY: cos_full = full Procrustes alignment cosine (ceiling)")
629
- print(f" cos_lifted = projected Procrustes lifted back to N-d (what we get)")
630
- print(f" rot_trace = tr(R_full^T @ R_proj)/N (1.0 = same rotation)")
631
- print(f" NN_agree = nearest-neighbor classification agreement (task proxy)")
632
- print(f" {'─'*76}")
633
- print(f" If cos_lifted β‰ˆ cos_full and NN_agree > 0.95, projection is safe.")
 
 
 
 
 
 
 
 
 
 
 
634
 
635
  return all_results
636
 
 
492
 
493
 
494
  def procrustes_alignment_quality(N=48, k=24, n_samples=5000):
495
+ """Compare 5 methods of applying rank-k Procrustes back to N-d.
496
+
497
+ Methods:
498
+ 1. full: Full N-d Procrustes (ceiling)
499
+ 2. pinv: P @ R_k @ pinv(P) β€” naive lift (broken baseline)
500
+ 3. lerp: (1-Ξ±)I + Ξ±*(P @ R_k @ pinv(P)) β€” blend with identity
501
+ 4. slerp: matrix_exp(Ξ± * matrix_log(R_lifted)) β€” geodesic on SO(N)
502
+ 5. subspace: Rotate in-subspace component, preserve orthogonal complement
503
+ 6. stay_k: Don't lift β€” compare in k-d (reference for k-d quality)
504
  """
505
  device = 'cuda'
506
 
507
  # Create two embedding spaces with shared low-rank structure + noise
508
+ shared_rank = min(N // 2, 32)
 
509
  shared_basis = torch.randn(shared_rank, N, device=device)
510
+ shared_basis = torch.linalg.qr(shared_basis.T).Q.T
511
 
 
512
  coeffs_src = torch.randn(n_samples, shared_rank, device=device)
513
  coeffs_tgt = torch.randn(n_samples, shared_rank, device=device) * 0.8 + coeffs_src * 0.5
514
  noise_scale = 0.3
 
516
  source = coeffs_src @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
517
  target = coeffs_tgt @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
518
 
 
519
  source = source - source.mean(0, keepdim=True)
520
  target = target - target.mean(0, keepdim=True)
521
 
522
+ # ═══ Full N-d Procrustes (ceiling) ═══
523
+ C_full = source.T @ target
524
+ U_f, _, Vh_f = torch.linalg.svd(C_full)
525
+ R_full = U_f @ Vh_f
 
 
526
  aligned_full = source @ R_full
527
+ cos_full = F.cosine_similarity(aligned_full, target, dim=-1).mean().item()
528
 
529
  # ═══ Projected k-d Procrustes ═══
 
 
530
  P = torch.randn(N, k, device=device) / math.sqrt(k)
531
+ # Orthogonalize P for cleaner subspace decomposition
532
+ P = torch.linalg.qr(P).Q # (N, k) orthonormal columns
533
+
534
+ src_proj = source @ P
535
+ tgt_proj = target @ P
536
+
537
+ C_proj = src_proj.T @ tgt_proj
538
+ U_p, _, Vh_p = torch.linalg.svd(C_proj)
539
+ R_k = U_p @ Vh_p # (k, k) optimal rotation in k-d
540
+
541
+ # ═══ Method 1: Naive pinv lift (broken baseline) ═══
542
+ P_pinv = torch.linalg.pinv(P)
543
+ R_pinv = P @ R_k @ P_pinv
544
+ aligned_pinv = source @ R_pinv
545
+ cos_pinv = F.cosine_similarity(aligned_pinv, target, dim=-1).mean().item()
546
+
547
+ # ═══ Method 2: LERP β€” blend projected rotation with identity ═══
548
+ # Test multiple Ξ± values, pick best
549
+ I_N = torch.eye(N, device=device)
550
+ best_lerp_cos = -1.0
551
+ best_lerp_alpha = 0.0
552
+ lerp_results = {}
553
+ for alpha in [0.3, 0.5, 0.7, 0.9, 1.0]:
554
+ R_lerp = (1.0 - alpha) * I_N + alpha * R_pinv
555
+ aligned_lerp = source @ R_lerp
556
+ c = F.cosine_similarity(aligned_lerp, target, dim=-1).mean().item()
557
+ lerp_results[alpha] = c
558
+ if c > best_lerp_cos:
559
+ best_lerp_cos = c
560
+ best_lerp_alpha = alpha
561
+ # Also get NN agreement for best lerp
562
+ R_lerp_best = (1.0 - best_lerp_alpha) * I_N + best_lerp_alpha * R_pinv
563
+ aligned_lerp_best = source @ R_lerp_best
564
+
565
+ # ═══ Method 3: SLERP β€” geodesic interpolation on rotation manifold ═══
566
+ # R_pinv may not be exactly orthogonal, so clean it first
567
+ U_clean, _, Vh_clean = torch.linalg.svd(R_pinv)
568
+ R_ortho = U_clean @ Vh_clean # closest orthogonal matrix
569
+
570
+ best_slerp_cos = -1.0
571
+ best_slerp_alpha = 0.0
572
+ try:
573
+ log_R = torch.linalg.matrix_log(R_ortho.to(torch.complex64)).real
574
+ slerp_works = True
575
+ except Exception:
576
+ slerp_works = False
577
+ log_R = None
578
+
579
+ if slerp_works:
580
+ for alpha in [0.3, 0.5, 0.7, 0.9, 1.0]:
581
+ R_slerp = torch.matrix_exp(alpha * log_R)
582
+ aligned_slerp = source @ R_slerp
583
+ c = F.cosine_similarity(aligned_slerp, target, dim=-1).mean().item()
584
+ if c > best_slerp_cos:
585
+ best_slerp_cos = c
586
+ best_slerp_alpha = alpha
587
+ R_slerp_best = torch.matrix_exp(best_slerp_alpha * log_R)
588
+ aligned_slerp_best = source @ R_slerp_best
589
+ else:
590
+ best_slerp_cos = cos_pinv
591
+ best_slerp_alpha = -1.0
592
+ aligned_slerp_best = aligned_pinv
593
+
594
+ # ═══ Method 4: Subspace-preserving rotation ═══
595
+ # Decompose source into in-subspace and orthogonal complement
596
+ # P @ P^T is the projector onto the k-d subspace (P has orthonormal columns)
597
+ src_in = source @ P # (n, k) β€” coefficients in subspace
598
+ src_perp = source - src_in @ P.T # (n, N) β€” orthogonal complement
599
+
600
+ # Rotate only the in-subspace component
601
+ src_in_rotated = src_in @ R_k # (n, k) β€” rotated in k-d
602
+ aligned_subspace = src_in_rotated @ P.T + src_perp # lift rotated + add perp back
603
+ cos_subspace = F.cosine_similarity(aligned_subspace, target, dim=-1).mean().item()
604
+
605
+ # ═══ Method 5: Stay in k-d (don't lift, reference) ═══
606
+ aligned_k = src_proj @ R_k
607
+ cos_stay_k = F.cosine_similarity(aligned_k, tgt_proj, dim=-1).mean().item()
608
+
609
+ # ═══ NN agreement for all methods ═══
610
  n_anchor = min(100, n_samples // 2)
 
 
 
 
611
 
612
+ def _nn_agree(aligned_a, aligned_b):
613
+ anc_a, anc_b = aligned_a[:n_anchor], aligned_b[:n_anchor]
614
+ q_a, q_b = aligned_a[n_anchor:], aligned_b[n_anchor:]
615
+ nn_a = (q_a @ anc_a.T).argmax(-1)
616
+ nn_b = (q_b @ anc_b.T).argmax(-1)
617
+ return (nn_a == nn_b).float().mean().item()
618
+
619
+ nn_pinv = _nn_agree(aligned_full, aligned_pinv)
620
+ nn_lerp = _nn_agree(aligned_full, aligned_lerp_best)
621
+ nn_slerp = _nn_agree(aligned_full, aligned_slerp_best)
622
+ nn_subspace = _nn_agree(aligned_full, aligned_subspace)
623
 
624
  return {
625
  'N': N, 'k': k,
626
+ 'cos_full': cos_full,
627
+ 'cos_pinv': cos_pinv,
628
+ 'cos_lerp': best_lerp_cos, 'lerp_alpha': best_lerp_alpha,
629
+ 'cos_slerp': best_slerp_cos, 'slerp_alpha': best_slerp_alpha,
630
+ 'cos_subspace': cos_subspace,
631
+ 'cos_stay_k': cos_stay_k,
632
+ 'nn_pinv': nn_pinv, 'nn_lerp': nn_lerp,
633
+ 'nn_slerp': nn_slerp, 'nn_subspace': nn_subspace,
634
+ 'lerp_all': lerp_results,
635
  }
636
 
637
 
638
  def profile_procrustes_quality():
639
+ """Compare all Procrustes lift-back methods."""
640
+ print(f"\n{'='*120}")
641
+ print(f" PROCRUSTES ALIGNMENT: 5 methods of applying rank-k rotation to N-d space")
642
+ print(f" cos = mean cosine similarity after alignment (higher = better, full = ceiling)")
643
+ print(f" NN = nearest-neighbor agreement with full Procrustes (1.0 = identical downstream)")
644
+ print(f"{'='*120}")
645
 
646
  configs = [
647
+ (32, [8, 16, 24]),
648
+ (48, [8, 16, 24, 32]),
649
+ (64, [8, 16, 24, 32]),
650
+ (96, [16, 24, 32, 48]),
651
+ (128, [16, 24, 32, 48, 64]),
652
  ]
653
 
654
  all_results = []
655
 
656
  for N, ranks in configs:
657
  print(f"\n N={N}:")
658
+ print(f" {'k':>5} {'full':>7} {'pinv':>7} {'lerp':>7} {'(Ξ±)':>4}"
659
+ f" {'slerp':>7} {'(Ξ±)':>4} {'subspc':>7} {'stay_k':>7}"
660
+ f" β”‚ {'nn_pv':>6} {'nn_lr':>6} {'nn_sl':>6} {'nn_ss':>6}")
661
+ print(f" {'─'*105}")
662
 
663
  for k in ranks:
664
  if k >= N:
665
  continue
666
  q = procrustes_alignment_quality(N=N, k=k)
667
+
668
+ sl_alpha = f"{q['slerp_alpha']:.1f}" if q['slerp_alpha'] >= 0 else " err"
669
+
670
+ print(f" {k:>5} {q['cos_full']:>7.4f} {q['cos_pinv']:>7.4f}"
671
+ f" {q['cos_lerp']:>7.4f} {q['lerp_alpha']:>3.1f}"
672
+ f" {q['cos_slerp']:>7.4f} {sl_alpha:>4}"
673
+ f" {q['cos_subspace']:>7.4f} {q['cos_stay_k']:>7.4f}"
674
+ f" β”‚ {q['nn_pinv']:>6.3f} {q['nn_lerp']:>6.3f}"
675
+ f" {q['nn_slerp']:>6.3f} {q['nn_subspace']:>6.3f}")
676
  all_results.append(q)
677
 
678
+ # Winner summary
679
+ print(f"\n {'═'*105}")
680
+ print(f" WINNER PER CONFIG (closest cos to full, highest NN agreement):")
681
+ print(f" {'═'*105}")
682
+ for q in all_results:
683
+ methods = {
684
+ 'pinv': q['cos_pinv'], 'lerp': q['cos_lerp'],
685
+ 'slerp': q['cos_slerp'], 'subspace': q['cos_subspace'],
686
+ }
687
+ best_method = max(methods, key=methods.get)
688
+ best_cos = methods[best_method]
689
+ gap = q['cos_full'] - best_cos
690
+ nn_methods = {
691
+ 'pinv': q['nn_pinv'], 'lerp': q['nn_lerp'],
692
+ 'slerp': q['nn_slerp'], 'subspace': q['nn_subspace'],
693
+ }
694
+ best_nn_method = max(nn_methods, key=nn_methods.get)
695
+ print(f" N={q['N']:>3} k={q['k']:>3}: best_cos={best_method:>8} ({best_cos:.4f}, gap={gap:.4f})"
696
+ f" best_nn={best_nn_method:>8} ({nn_methods[best_nn_method]:.3f})")
697
 
698
  return all_results
699