Update svd_triton_gram_newton.py
Browse files- svd_triton_gram_newton.py +158 -95
svd_triton_gram_newton.py
CHANGED
|
@@ -492,25 +492,23 @@ def projected_svd_quality(A, target_rank=24):
|
|
| 492 |
|
| 493 |
|
| 494 |
def procrustes_alignment_quality(N=48, k=24, n_samples=5000):
|
| 495 |
-
"""
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
1.
|
| 499 |
-
2.
|
| 500 |
-
3.
|
| 501 |
-
4.
|
| 502 |
-
|
| 503 |
-
|
| 504 |
"""
|
| 505 |
device = 'cuda'
|
| 506 |
|
| 507 |
# Create two embedding spaces with shared low-rank structure + noise
|
| 508 |
-
|
| 509 |
-
shared_rank = min(N // 2, 32) # true shared structure
|
| 510 |
shared_basis = torch.randn(shared_rank, N, device=device)
|
| 511 |
-
shared_basis = torch.linalg.qr(shared_basis.T).Q.T
|
| 512 |
|
| 513 |
-
# Source and target share the basis but with different coefficients + noise
|
| 514 |
coeffs_src = torch.randn(n_samples, shared_rank, device=device)
|
| 515 |
coeffs_tgt = torch.randn(n_samples, shared_rank, device=device) * 0.8 + coeffs_src * 0.5
|
| 516 |
noise_scale = 0.3
|
|
@@ -518,119 +516,184 @@ def procrustes_alignment_quality(N=48, k=24, n_samples=5000):
|
|
| 518 |
source = coeffs_src @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
|
| 519 |
target = coeffs_tgt @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
|
| 520 |
|
| 521 |
-
# Center
|
| 522 |
source = source - source.mean(0, keepdim=True)
|
| 523 |
target = target - target.mean(0, keepdim=True)
|
| 524 |
|
| 525 |
-
# βββ Full N-d Procrustes βββ
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
R_full = U_f @ Vh_f # (N, N) optimal rotation
|
| 530 |
-
|
| 531 |
aligned_full = source @ R_full
|
| 532 |
-
cos_full = F.cosine_similarity(aligned_full, target, dim=-1)
|
| 533 |
|
| 534 |
# βββ Projected k-d Procrustes βββ
|
| 535 |
-
# Project both spaces to k-d, align there, lift back
|
| 536 |
-
# Random projection
|
| 537 |
P = torch.randn(N, k, device=device) / math.sqrt(k)
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
n_anchor = min(100, n_samples // 2)
|
| 573 |
-
anchors_full = aligned_full[:n_anchor]
|
| 574 |
-
anchors_lift = aligned_lifted[:n_anchor]
|
| 575 |
-
queries_full = aligned_full[n_anchor:]
|
| 576 |
-
queries_lift = aligned_lifted[n_anchor:]
|
| 577 |
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
|
| 582 |
return {
|
| 583 |
'N': N, 'k': k,
|
| 584 |
-
'
|
| 585 |
-
'
|
| 586 |
-
'
|
| 587 |
-
'
|
| 588 |
-
'
|
| 589 |
-
'
|
| 590 |
-
'
|
|
|
|
|
|
|
| 591 |
}
|
| 592 |
|
| 593 |
|
| 594 |
def profile_procrustes_quality():
|
| 595 |
-
"""Compare Procrustes
|
| 596 |
-
print(f"\n{'='*
|
| 597 |
-
print(f" PROCRUSTES ALIGNMENT
|
| 598 |
-
print(f"
|
| 599 |
-
print(f"
|
|
|
|
| 600 |
|
| 601 |
configs = [
|
| 602 |
-
(32, [8,
|
| 603 |
-
(48, [8,
|
| 604 |
-
(64, [8,
|
| 605 |
-
(96, [
|
| 606 |
-
(128, [
|
| 607 |
]
|
| 608 |
|
| 609 |
all_results = []
|
| 610 |
|
| 611 |
for N, ranks in configs:
|
| 612 |
print(f"\n N={N}:")
|
| 613 |
-
print(f" {'k':>5} {'
|
| 614 |
-
f" {'
|
| 615 |
-
|
|
|
|
| 616 |
|
| 617 |
for k in ranks:
|
| 618 |
if k >= N:
|
| 619 |
continue
|
| 620 |
q = procrustes_alignment_quality(N=N, k=k)
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
all_results.append(q)
|
| 625 |
|
| 626 |
-
#
|
| 627 |
-
print(f"\n {'
|
| 628 |
-
print(f"
|
| 629 |
-
print(f"
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
|
| 635 |
return all_results
|
| 636 |
|
|
|
|
| 492 |
|
| 493 |
|
| 494 |
def procrustes_alignment_quality(N=48, k=24, n_samples=5000):
    """Compare methods of applying a rank-k Procrustes rotation back to N-d.

    Two synthetic embedding spaces sharing a low-rank structure are built,
    then aligned with:
      1. full:     Full N-d Procrustes (quality ceiling)
      2. pinv:     P @ R_k @ pinv(P) — naive lift (broken baseline)
      3. lerp:     (1-α)I + α*(P @ R_k @ pinv(P)) — blend with identity
      4. slerp:    matrix_exp(α * log(R_lifted)) — geodesic on SO(N)
      5. subspace: rotate in-subspace component, preserve orthogonal complement
      6. stay_k:   don't lift — compare in k-d (reference for k-d quality)

    Args:
        N: ambient embedding dimension.
        k: projected rank for the reduced Procrustes problem (expected k < N).
        n_samples: number of sample points per embedding space.

    Returns:
        dict with per-method mean cosine similarities ('cos_*'), the winning
        lerp/slerp blend factors ('lerp_alpha'/'slerp_alpha'; slerp_alpha is
        -1.0 if the matrix-log path failed), nearest-neighbor agreement with
        the full solution ('nn_*'), and the full lerp α sweep ('lerp_all').
    """
    # BUGFIX: was hard-coded to 'cuda'; fall back to CPU so the benchmark
    # also runs on CUDA-less machines (identical behavior where CUDA exists).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create two embedding spaces with shared low-rank structure + noise.
    shared_rank = min(N // 2, 32)  # true shared structure
    shared_basis = torch.randn(shared_rank, N, device=device)
    shared_basis = torch.linalg.qr(shared_basis.T).Q.T  # orthonormal rows

    # Source and target share the basis but with correlated coefficients.
    coeffs_src = torch.randn(n_samples, shared_rank, device=device)
    coeffs_tgt = torch.randn(n_samples, shared_rank, device=device) * 0.8 + coeffs_src * 0.5
    noise_scale = 0.3
    source = coeffs_src @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
    target = coeffs_tgt @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)

    # Center both clouds (Procrustes assumes centered data).
    source = source - source.mean(0, keepdim=True)
    target = target - target.mean(0, keepdim=True)

    # ─── Full N-d Procrustes (ceiling) ───
    C_full = source.T @ target
    U_f, _, Vh_f = torch.linalg.svd(C_full)
    R_full = U_f @ Vh_f  # (N, N) optimal rotation
    aligned_full = source @ R_full
    cos_full = F.cosine_similarity(aligned_full, target, dim=-1).mean().item()

    # ─── Projected k-d Procrustes ───
    P = torch.randn(N, k, device=device) / math.sqrt(k)
    # Orthogonalize P for a clean subspace decomposition.
    P = torch.linalg.qr(P).Q  # (N, k) orthonormal columns

    src_proj = source @ P
    tgt_proj = target @ P

    C_proj = src_proj.T @ tgt_proj
    U_p, _, Vh_p = torch.linalg.svd(C_proj)
    R_k = U_p @ Vh_p  # (k, k) optimal rotation in k-d

    # ─── Method 1: Naive pinv lift (broken baseline) ───
    P_pinv = torch.linalg.pinv(P)
    R_pinv = P @ R_k @ P_pinv  # rank-k "rotation" lifted to N-d
    aligned_pinv = source @ R_pinv
    cos_pinv = F.cosine_similarity(aligned_pinv, target, dim=-1).mean().item()

    # ─── Method 2: LERP — blend projected rotation with identity ───
    # Test multiple α values, pick best.
    I_N = torch.eye(N, device=device)
    best_lerp_cos = -1.0
    best_lerp_alpha = 0.0
    lerp_results = {}
    for alpha in [0.3, 0.5, 0.7, 0.9, 1.0]:
        R_lerp = (1.0 - alpha) * I_N + alpha * R_pinv
        aligned_lerp = source @ R_lerp
        c = F.cosine_similarity(aligned_lerp, target, dim=-1).mean().item()
        lerp_results[alpha] = c
        if c > best_lerp_cos:
            best_lerp_cos = c
            best_lerp_alpha = alpha
    # Re-materialize the winner for the NN-agreement check below.
    R_lerp_best = (1.0 - best_lerp_alpha) * I_N + best_lerp_alpha * R_pinv
    aligned_lerp_best = source @ R_lerp_best

    # ─── Method 3: SLERP — geodesic interpolation on rotation manifold ───
    # R_pinv is rank-deficient (not orthogonal), so project it to the
    # closest orthogonal matrix first.
    U_clean, _, Vh_clean = torch.linalg.svd(R_pinv)
    R_ortho = U_clean @ Vh_clean  # closest orthogonal matrix

    best_slerp_cos = -1.0
    best_slerp_alpha = 0.0
    try:
        # BUGFIX: PyTorch has no torch.linalg.matrix_log, so the original
        # code always fell into the except branch and SLERP silently never
        # ran. An orthogonal matrix is normal, hence diagonalizable over C:
        #   R = V diag(w) V^-1  =>  log R = V diag(log w) V^-1.
        evals, evecs = torch.linalg.eig(R_ortho)
        log_R = (evecs @ torch.diag(torch.log(evals)) @ torch.linalg.inv(evecs)).real
        slerp_works = True
    except Exception:
        # e.g. eigendecomposition failure on degenerate input
        slerp_works = False
        log_R = None

    if slerp_works:
        for alpha in [0.3, 0.5, 0.7, 0.9, 1.0]:
            R_slerp = torch.matrix_exp(alpha * log_R)
            aligned_slerp = source @ R_slerp
            c = F.cosine_similarity(aligned_slerp, target, dim=-1).mean().item()
            if c > best_slerp_cos:
                best_slerp_cos = c
                best_slerp_alpha = alpha
        R_slerp_best = torch.matrix_exp(best_slerp_alpha * log_R)
        aligned_slerp_best = source @ R_slerp_best
    else:
        # Fall back to the pinv result; α = -1 flags the failure to callers.
        best_slerp_cos = cos_pinv
        best_slerp_alpha = -1.0
        aligned_slerp_best = aligned_pinv

    # ─── Method 4: Subspace-preserving rotation ───
    # P @ P.T is the projector onto the k-d subspace (P has orthonormal
    # columns): rotate only the in-subspace component, keep the complement.
    src_in = source @ P               # (n, k) — coefficients in subspace
    src_perp = source - src_in @ P.T  # (n, N) — orthogonal complement
    src_in_rotated = src_in @ R_k     # (n, k) — rotated in k-d
    aligned_subspace = src_in_rotated @ P.T + src_perp  # lift rotated + add perp back
    cos_subspace = F.cosine_similarity(aligned_subspace, target, dim=-1).mean().item()

    # ─── Method 5: Stay in k-d (don't lift, reference) ───
    aligned_k = src_proj @ R_k
    cos_stay_k = F.cosine_similarity(aligned_k, tgt_proj, dim=-1).mean().item()

    # ─── NN agreement for all methods ───
    n_anchor = min(100, n_samples // 2)

    def _nn_agree(aligned_a, aligned_b):
        # Fraction of queries whose nearest anchor matches across alignments.
        anc_a, anc_b = aligned_a[:n_anchor], aligned_b[:n_anchor]
        q_a, q_b = aligned_a[n_anchor:], aligned_b[n_anchor:]
        nn_a = (q_a @ anc_a.T).argmax(-1)
        nn_b = (q_b @ anc_b.T).argmax(-1)
        return (nn_a == nn_b).float().mean().item()

    nn_pinv = _nn_agree(aligned_full, aligned_pinv)
    nn_lerp = _nn_agree(aligned_full, aligned_lerp_best)
    nn_slerp = _nn_agree(aligned_full, aligned_slerp_best)
    nn_subspace = _nn_agree(aligned_full, aligned_subspace)

    return {
        'N': N, 'k': k,
        'cos_full': cos_full,
        'cos_pinv': cos_pinv,
        'cos_lerp': best_lerp_cos, 'lerp_alpha': best_lerp_alpha,
        'cos_slerp': best_slerp_cos, 'slerp_alpha': best_slerp_alpha,
        'cos_subspace': cos_subspace,
        'cos_stay_k': cos_stay_k,
        'nn_pinv': nn_pinv, 'nn_lerp': nn_lerp,
        'nn_slerp': nn_slerp, 'nn_subspace': nn_subspace,
        'lerp_all': lerp_results,
    }
|
| 636 |
|
| 637 |
|
| 638 |
def profile_procrustes_quality():
    """Compare all Procrustes lift-back methods.

    Runs `procrustes_alignment_quality` over a grid of (N, ranks) configs,
    prints a per-config table of cosine / NN-agreement metrics, then a
    per-config winner summary.

    Returns:
        list of per-config result dicts from `procrustes_alignment_quality`.
    """
    print(f"\n{'='*120}")
    print(f" PROCRUSTES ALIGNMENT: 5 methods of applying rank-k rotation to N-d space")
    print(f" cos = mean cosine similarity after alignment (higher = better, full = ceiling)")
    print(f" NN = nearest-neighbor agreement with full Procrustes (1.0 = identical downstream)")
    print(f"{'='*120}")

    # (N, list of projected ranks k to test); k >= N entries are skipped below.
    configs = [
        (32, [8, 16, 24]),
        (48, [8, 16, 24, 32]),
        (64, [8, 16, 24, 32]),
        (96, [16, 24, 32, 48]),
        (128, [16, 24, 32, 48, 64]),
    ]

    all_results = []

    for N, ranks in configs:
        print(f"\n  N={N}:")
        # Table header: one cosine column per method, then NN-agreement columns.
        print(f"  {'k':>5} {'full':>7} {'pinv':>7} {'lerp':>7} {'(α)':>4}"
              f" {'slerp':>7} {'(α)':>4} {'subspc':>7} {'stay_k':>7}"
              f" │ {'nn_pv':>6} {'nn_lr':>6} {'nn_sl':>6} {'nn_ss':>6}")
        print(f"  {'─'*105}")

        for k in ranks:
            if k >= N:
                continue  # projection rank must be strictly below ambient dim
            q = procrustes_alignment_quality(N=N, k=k)

            # slerp_alpha == -1 flags that the matrix-log path failed.
            sl_alpha = f"{q['slerp_alpha']:.1f}" if q['slerp_alpha'] >= 0 else " err"

            print(f"  {k:>5} {q['cos_full']:>7.4f} {q['cos_pinv']:>7.4f}"
                  f" {q['cos_lerp']:>7.4f} {q['lerp_alpha']:>3.1f}"
                  f" {q['cos_slerp']:>7.4f} {sl_alpha:>4}"
                  f" {q['cos_subspace']:>7.4f} {q['cos_stay_k']:>7.4f}"
                  f" │ {q['nn_pinv']:>6.3f} {q['nn_lerp']:>6.3f}"
                  f" {q['nn_slerp']:>6.3f} {q['nn_subspace']:>6.3f}")
            all_results.append(q)

    # Winner summary
    print(f"\n  {'─'*105}")
    print(f"  WINNER PER CONFIG (closest cos to full, highest NN agreement):")
    print(f"  {'─'*105}")
    for q in all_results:
        # Best method by mean cosine similarity (full is excluded: it is the ceiling).
        methods = {
            'pinv': q['cos_pinv'], 'lerp': q['cos_lerp'],
            'slerp': q['cos_slerp'], 'subspace': q['cos_subspace'],
        }
        best_method = max(methods, key=methods.get)
        best_cos = methods[best_method]
        gap = q['cos_full'] - best_cos  # distance to the full-Procrustes ceiling
        # Best method by NN agreement with the full solution.
        nn_methods = {
            'pinv': q['nn_pinv'], 'lerp': q['nn_lerp'],
            'slerp': q['nn_slerp'], 'subspace': q['nn_subspace'],
        }
        best_nn_method = max(nn_methods, key=nn_methods.get)
        print(f"  N={q['N']:>3} k={q['k']:>3}: best_cos={best_method:>8} ({best_cos:.4f}, gap={gap:.4f})"
              f" best_nn={best_nn_method:>8} ({nn_methods[best_nn_method]:.3f})")

    return all_results
|
| 699 |
|