Create kernel_profiler.py
Browse files- kernel_profiler.py +1419 -0
kernel_profiler.py
ADDED
|
@@ -0,0 +1,1419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
triton_svd_general.py β Generalized batched thin SVD for (B, M, N) matrices.
|
| 3 |
+
|
| 4 |
+
Three strategies, auto-dispatched by N:
|
| 5 |
+
N=2: Fused Triton kernel β closed-form 2Γ2 eigensolve in registers
|
| 6 |
+
N=3: Fused Triton kernel β cyclic Jacobi in registers (from session start)
|
| 7 |
+
Nβ₯4: Gram-Eigh hybrid β Triton G=A^T A + torch.linalg.eigh + Triton U recovery
|
| 8 |
+
|
| 9 |
+
All methods exploit the thin-matrix shortcut: decompose via the NΓN Gram
|
| 10 |
+
matrix G=A^T A rather than working on the full MΓN matrix directly.
|
| 11 |
+
|
| 12 |
+
Mathematical lineage:
|
| 13 |
+
Eckart-Young (1936): G = A^T A β eigenvalues of G = ΟΒ² of A
|
| 14 |
+
Jacobi (1846): Cyclic Givens rotations for symmetric eigendecomposition
|
| 15 |
+
Golub-Reinsch (1970): U = A V S^{-1} recovery
|
| 16 |
+
Batcher (1968): Sorting network for eigenvalue ordering
|
| 17 |
+
|
| 18 |
+
Author: AbstractPhil / Claude Opus 4.6
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import triton
|
| 22 |
+
import triton.language as tl
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
import math
|
| 26 |
+
import time
|
| 27 |
+
import json
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# =============================================================================
# KERNEL 1: Fused SVD for (B, M, 2) -- closed-form 2x2 eigensolve
# =============================================================================
|
| 33 |
+
|
| 34 |
+
@triton.jit
def _svd2_kernel(
    A_ptr, U_ptr, S_ptr, Vh_ptr,
    M: tl.constexpr, BLOCK_M: tl.constexpr, EPS: tl.constexpr,
):
    """Fused thin SVD of one (M, 2) matrix per program.

    Launched on a (B,) grid; program `bid` processes batch element `bid`.
    All buffers are contiguous float32:
        A_ptr  : (B, M, 2) input matrices
        U_ptr  : (B, M, 2) left singular vectors
        S_ptr  : (B, 2)    singular values, descending
        Vh_ptr : (B, 2, 2) V^T (right singular vectors, transposed)

    The 2x2 symmetric eigenproblem on G = A^T A has a closed form: a
    single Jacobi (Givens) rotation fully diagonalizes G, so no trig
    functions and no iteration are required.
    """
    bid = tl.program_id(0)
    base = bid * M * 2

    # Stage 1: G = A^T A. G is 2x2 symmetric, so three scalar
    # accumulators (g00, g01, g11) suffice; rows are streamed in
    # BLOCK_M-sized chunks with a tail mask.
    g00 = tl.zeros([], dtype=tl.float32)
    g01 = tl.zeros([], dtype=tl.float32)
    g11 = tl.zeros([], dtype=tl.float32)

    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M)
        row_idx = block_start + offs
        mask = row_idx < M  # guard the final partial block
        a0 = tl.load(A_ptr + base + row_idx * 2 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 2 + 1, mask=mask, other=0.0).to(tl.float32)
        g00 += tl.sum(a0 * a0)
        g01 += tl.sum(a0 * a1)
        g11 += tl.sum(a1 * a1)

    # Stage 2: 2x2 eigendecomposition via a single Jacobi rotation.
    # tau = cot(2*theta); t = tan(theta) chosen as the smaller-magnitude
    # root for numerical stability (standard Jacobi formulation).
    off_diag = g01
    diag_diff = g11 - g00
    abs_off = tl.abs(off_diag)
    tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
    t = tl.where(abs_off > EPS,
                 tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)),
                 0.0)
    c = 1.0 / tl.sqrt(1.0 + t * t)
    s = t * c

    # Eigenvalues of G after the rotation (these are sigma^2 of A).
    eig0 = c * c * g00 - 2.0 * s * c * g01 + s * s * g11
    eig1 = s * s * g00 + 2.0 * s * c * g01 + c * c * g11

    # Singular values; the EPS floor guards sqrt against tiny negative
    # values produced by floating-point cancellation.
    s0 = tl.sqrt(tl.maximum(eig0, EPS))
    s1 = tl.sqrt(tl.maximum(eig1, EPS))

    # V starts as the identity with the Jacobi rotation applied.
    v00 = c; v01 = s
    v10 = -s; v11 = c

    # Sort singular values descending; swap V's columns in step.
    do_swap = s0 < s1
    s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)
    tv = v00; v00 = tl.where(do_swap, v01, v00); v01 = tl.where(do_swap, tv, v01)
    tv = v10; v10 = tl.where(do_swap, v11, v10); v11 = tl.where(do_swap, tv, v11)

    # Write S.
    s_base = bid * 2
    tl.store(S_ptr + s_base + 0, s0)
    tl.store(S_ptr + s_base + 1, s1)

    # Write Vh = V^T (note the transposed element order).
    vh_base = bid * 4
    tl.store(Vh_ptr + vh_base + 0, v00); tl.store(Vh_ptr + vh_base + 1, v10)
    tl.store(Vh_ptr + vh_base + 2, v01); tl.store(Vh_ptr + vh_base + 3, v11)

    # Stage 3: U = A @ V @ diag(1/S) (Golub-Reinsch recovery).
    # EPS in the denominator keeps rank-deficient inputs finite.
    inv_s0 = 1.0 / (s0 + EPS)
    inv_s1 = 1.0 / (s1 + EPS)

    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M)
        row_idx = block_start + offs
        mask = row_idx < M
        a0 = tl.load(A_ptr + base + row_idx * 2 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 2 + 1, mask=mask, other=0.0).to(tl.float32)
        u0 = (a0 * v00 + a1 * v10) * inv_s0
        u1 = (a0 * v01 + a1 * v11) * inv_s1
        u_base = bid * M * 2  # same offset as `base`
        tl.store(U_ptr + u_base + row_idx * 2 + 0, u0, mask=mask)
        tl.store(U_ptr + u_base + row_idx * 2 + 1, u1, mask=mask)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def batched_svd2(A, block_m=128):
    """Fused Triton thin SVD for a batch of (B, M, 2) matrices.

    Returns (U, S, Vh) as float32 tensors of shapes (B, M, 2), (B, 2),
    (B, 2, 2), with singular values sorted descending per batch element.
    One kernel program handles one batch element.
    """
    assert A.ndim == 3 and A.shape[2] == 2
    batch, rows, _ = A.shape
    dev = A.device
    # Kernel expects a contiguous fp32 layout.
    src = A.contiguous().float()
    left = torch.empty((batch, rows, 2), dtype=torch.float32, device=dev)
    sigma = torch.empty((batch, 2), dtype=torch.float32, device=dev)
    right_h = torch.empty((batch, 2, 2), dtype=torch.float32, device=dev)
    # One program per batch element.
    _svd2_kernel[(batch,)](src, left, sigma, right_h,
                           M=rows, BLOCK_M=block_m, EPS=1e-12)
    return left, sigma, right_h
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# =============================================================================
# KERNEL 2: Fused SVD for (B, M, 3) -- cyclic Jacobi (original kernel)
# =============================================================================
|
| 135 |
+
|
| 136 |
+
@triton.jit
def _svd3_kernel(
    A_ptr, U_ptr, S_ptr, Vh_ptr,
    M: tl.constexpr, BLOCK_M: tl.constexpr,
    JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,
):
    """Fused thin SVD of one (M, 3) matrix per program via cyclic Jacobi.

    Launched on a (B,) grid. Contiguous float32 buffers:
        A_ptr  : (B, M, 3) input
        U_ptr  : (B, M, 3) left singular vectors
        S_ptr  : (B, 3)    singular values, descending
        Vh_ptr : (B, 3, 3) V^T

    Each of the JACOBI_ITERS sweeps rotates the three off-diagonal pairs
    (0,1), (0,2), (1,2) of G = A^T A to zero in turn, accumulating the
    rotations into V. Eigenvalue order is then fixed by a 3-element
    sorting network, and U is recovered as A V diag(1/S).
    """
    bid = tl.program_id(0)
    # Six scalar accumulators hold the symmetric 3x3 Gram matrix.
    g00 = tl.zeros([], dtype=tl.float32); g01 = tl.zeros([], dtype=tl.float32)
    g02 = tl.zeros([], dtype=tl.float32); g11 = tl.zeros([], dtype=tl.float32)
    g12 = tl.zeros([], dtype=tl.float32); g22 = tl.zeros([], dtype=tl.float32)
    base = bid * M * 3
    # Stage 1: stream rows in BLOCK_M chunks to build G = A^T A.
    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M); row_idx = block_start + offs; mask = row_idx < M
        a0 = tl.load(A_ptr + base + row_idx * 3 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 3 + 1, mask=mask, other=0.0).to(tl.float32)
        a2 = tl.load(A_ptr + base + row_idx * 3 + 2, mask=mask, other=0.0).to(tl.float32)
        g00 += tl.sum(a0*a0); g01 += tl.sum(a0*a1); g02 += tl.sum(a0*a2)
        g11 += tl.sum(a1*a1); g12 += tl.sum(a1*a2); g22 += tl.sum(a2*a2)
    # V starts as the 3x3 identity; all rotations are accumulated into it.
    v00=1.0;v01=0.0;v02=0.0;v10=0.0;v11=1.0;v12=0.0;v20=0.0;v21=0.0;v22=1.0
    # Stage 2: cyclic Jacobi sweeps over the three off-diagonal planes.
    for _ in range(JACOBI_ITERS):
        # --- rotate in the (0,1) plane to annihilate g01 ---
        off_diag=g01;diag_diff=g11-g00;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng00=c*c*g00-2.0*s*c*g01+s*s*g11;ng11=s*s*g00+2.0*s*c*g01+c*c*g11
        ng02=c*g02-s*g12;ng12=s*g02+c*g12
        g00=ng00;g11=ng11;g01=0.0;g02=ng02;g12=ng12
        nv00=c*v00-s*v01;nv01=s*v00+c*v01;nv10=c*v10-s*v11;nv11=s*v10+c*v11
        nv20=c*v20-s*v21;nv21=s*v20+c*v21
        v00=nv00;v01=nv01;v10=nv10;v11=nv11;v20=nv20;v21=nv21
        # --- rotate in the (0,2) plane to annihilate g02 ---
        off_diag=g02;diag_diff=g22-g00;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng00=c*c*g00-2.0*s*c*g02+s*s*g22;ng22=s*s*g00+2.0*s*c*g02+c*c*g22
        ng01=c*g01-s*g12;ng12b=s*g01+c*g12
        g00=ng00;g22=ng22;g02=0.0;g01=ng01;g12=ng12b
        nv00=c*v00-s*v02;nv02=s*v00+c*v02;nv10=c*v10-s*v12;nv12=s*v10+c*v12
        nv20=c*v20-s*v22;nv22=s*v20+c*v22
        v00=nv00;v02=nv02;v10=nv10;v12=nv12;v20=nv20;v22=nv22
        # --- rotate in the (1,2) plane to annihilate g12 ---
        off_diag=g12;diag_diff=g22-g11;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng11=c*c*g11-2.0*s*c*g12+s*s*g22;ng22=s*s*g11+2.0*s*c*g12+c*c*g22
        ng01=c*g01-s*g02;ng02b=s*g01+c*g02
        g11=ng11;g22=ng22;g12=0.0;g01=ng01;g02=ng02b
        nv01=c*v01-s*v02;nv02=s*v01+c*v02;nv11=c*v11-s*v12;nv12=s*v11+c*v12
        nv21=c*v21-s*v22;nv22=s*v21+c*v22
        v01=nv01;v02=nv02;v11=nv11;v12=nv12;v21=nv21;v22=nv22
    # Singular values = sqrt of the (now nearly diagonal) eigenvalues;
    # EPS floor guards against negative round-off.
    s0=tl.sqrt(tl.maximum(g00,EPS));s1=tl.sqrt(tl.maximum(g11,EPS));s2=tl.sqrt(tl.maximum(g22,EPS))
    # 3-element sorting network -- compare-swaps (0,1), (0,2), (1,2) --
    # permuting V's columns in step with the singular values.
    do_swap=s0<s1
    s0,s1=tl.where(do_swap,s1,s0),tl.where(do_swap,s0,s1)
    tv=v00;v00=tl.where(do_swap,v01,v00);v01=tl.where(do_swap,tv,v01)
    tv=v10;v10=tl.where(do_swap,v11,v10);v11=tl.where(do_swap,tv,v11)
    tv=v20;v20=tl.where(do_swap,v21,v20);v21=tl.where(do_swap,tv,v21)
    do_swap=s0<s2
    s0,s2=tl.where(do_swap,s2,s0),tl.where(do_swap,s0,s2)
    tv=v00;v00=tl.where(do_swap,v02,v00);v02=tl.where(do_swap,tv,v02)
    tv=v10;v10=tl.where(do_swap,v12,v10);v12=tl.where(do_swap,tv,v12)
    tv=v20;v20=tl.where(do_swap,v22,v20);v22=tl.where(do_swap,tv,v22)
    do_swap=s1<s2
    s1,s2=tl.where(do_swap,s2,s1),tl.where(do_swap,s1,s2)
    tv=v01;v01=tl.where(do_swap,v02,v01);v02=tl.where(do_swap,tv,v02)
    tv=v11;v11=tl.where(do_swap,v12,v11);v12=tl.where(do_swap,tv,v12)
    tv=v21;v21=tl.where(do_swap,v22,v21);v22=tl.where(do_swap,tv,v22)
    # Write S and Vh = V^T.
    s_base=bid*3
    tl.store(S_ptr+s_base+0,s0);tl.store(S_ptr+s_base+1,s1);tl.store(S_ptr+s_base+2,s2)
    vh_base=bid*9
    tl.store(Vh_ptr+vh_base+0,v00);tl.store(Vh_ptr+vh_base+1,v10);tl.store(Vh_ptr+vh_base+2,v20)
    tl.store(Vh_ptr+vh_base+3,v01);tl.store(Vh_ptr+vh_base+4,v11);tl.store(Vh_ptr+vh_base+5,v21)
    tl.store(Vh_ptr+vh_base+6,v02);tl.store(Vh_ptr+vh_base+7,v12);tl.store(Vh_ptr+vh_base+8,v22)
    # Stage 3: U = A V diag(1/S); EPS keeps rank-deficient cases finite.
    inv_s0=1.0/(s0+EPS);inv_s1=1.0/(s1+EPS);inv_s2=1.0/(s2+EPS)
    for block_start in range(0, M, BLOCK_M):
        offs=tl.arange(0,BLOCK_M);row_idx=block_start+offs;mask=row_idx<M
        a0=tl.load(A_ptr+base+row_idx*3+0,mask=mask,other=0.0).to(tl.float32)
        a1=tl.load(A_ptr+base+row_idx*3+1,mask=mask,other=0.0).to(tl.float32)
        a2=tl.load(A_ptr+base+row_idx*3+2,mask=mask,other=0.0).to(tl.float32)
        u0=(a0*v00+a1*v10+a2*v20)*inv_s0
        u1=(a0*v01+a1*v11+a2*v21)*inv_s1
        u2=(a0*v02+a1*v12+a2*v22)*inv_s2
        u_base=bid*M*3  # same offset as `base`
        tl.store(U_ptr+u_base+row_idx*3+0,u0,mask=mask)
        tl.store(U_ptr+u_base+row_idx*3+1,u1,mask=mask)
        tl.store(U_ptr+u_base+row_idx*3+2,u2,mask=mask)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def batched_svd3(A, block_m=128, jacobi_iters=6):
    """Fused Triton thin SVD for a batch of (B, M, 3) matrices.

    Returns (U, S, Vh) as float32 tensors of shapes (B, M, 3), (B, 3),
    (B, 3, 3), singular values descending. `jacobi_iters` controls the
    number of cyclic Jacobi sweeps inside the kernel.
    """
    assert A.ndim == 3 and A.shape[2] == 3
    batch, rows, _ = A.shape
    dev = A.device
    # Kernel expects a contiguous fp32 layout.
    src = A.contiguous().float()
    left = torch.empty((batch, rows, 3), dtype=torch.float32, device=dev)
    sigma = torch.empty((batch, 3), dtype=torch.float32, device=dev)
    right_h = torch.empty((batch, 3, 3), dtype=torch.float32, device=dev)
    # One program per batch element.
    _svd3_kernel[(batch,)](src, left, sigma, right_h,
                           M=rows, BLOCK_M=block_m,
                           JACOBI_ITERS=jacobi_iters, EPS=1e-12)
    return left, sigma, right_h
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# =============================================================================
# METHOD 3: Gram-Eigh hybrid for general N
#   G = A^T A (bmm) -> eigh(G) -> U = A V / S
# =============================================================================
|
| 240 |
+
|
| 241 |
+
def gram_eigh_svd(A):
    """Thin SVD via eigendecomposition of the Gram matrix (Eckart-Young).

    For A of shape (B, M, N):
      1. G = A^T A            -- (B, N, N) symmetric PSD, one bmm
      2. eigh(G)              -- eigenvalues ascending, flipped to descending
      3. S = sqrt(eigenvalues)
      4. U = A V / S          -- left singular vectors

    Returns (U, S, Vh) in float32: (B, M, N), (B, N), (B, N, N).
    """
    B, M, N = A.shape
    # Force full fp32 even inside an active autocast region: forming the
    # Gram matrix squares the condition number, so reduced precision is
    # unsafe here.
    with torch.amp.autocast('cuda', enabled=False):
        mat = A.float()
        gram = torch.bmm(mat.transpose(1, 2), mat)      # (B, N, N)
        evals, evecs = torch.linalg.eigh(gram)          # ascending
        evals, evecs = evals.flip(-1), evecs.flip(-1)   # -> descending
        sigma = torch.sqrt(evals.clamp(min=1e-12))      # (B, N)
        left = torch.bmm(mat, evecs) / sigma.unsqueeze(1)  # (B, M, N)
        right_h = evecs.transpose(-2, -1).contiguous()  # (B, N, N)
    return left, sigma, right_h
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# =============================================================================
# METHOD 4: Newton iterative SVD for large N (48+)
#   All bmm -> zero eigensolvers. Quadratic convergence.
# =============================================================================
|
| 269 |
+
|
| 270 |
+
def newton_svd(A, schulz_iters=10):
    """Thin SVD of (B, M, N) via eigendecomposition of the Gram matrix.

    Despite the name, this path does NOT run Newton-Schulz: the eigensolve
    on G = A^T A is the bottleneck that the iteration cannot avoid, so the
    body is a direct Gram-eigh decomposition. `schulz_iters` is accepted
    for API compatibility but is currently unused; the reusable iteration
    lives in newton_schulz_invsqrt (use that for Procrustes whitening).

    Returns (U, S, Vh) in float32: (B, M, N), (B, N), (B, N, N), with S
    descending.
    """
    B, M, N = A.shape  # unpack also validates rank-3 input
    mat = A.float()

    # Gram matrix: one bmm.
    gram = torch.bmm(mat.transpose(1, 2), mat)  # (B, N, N)

    # Symmetric eigendecomposition; eigh returns ascending order, so
    # flip both eigenvalues and eigenvectors to descending.
    evals, evecs = torch.linalg.eigh(gram)
    evals, evecs = evals.flip(-1), evecs.flip(-1)

    # Singular values; clamp guards sqrt against round-off negatives.
    sigma = torch.sqrt(evals.clamp(min=1e-12))

    # Left singular vectors via the Golub-Reinsch recovery U = A V / S.
    left = torch.bmm(mat, evecs) / sigma.unsqueeze(1)
    right_h = evecs.transpose(-2, -1).contiguous()

    return left, sigma, right_h
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def newton_schulz_invsqrt(G, iters=10):
    """Newton-Schulz iteration for G^{-1/2} of batched symmetric PSD matrices.

    Pure bmm, zero eigensolvers, quadratic convergence. Intended for
    Procrustes whitening: W = X @ newton_schulz_invsqrt(X^T X).

    Args:
        G: (B, N, N) symmetric PSD matrices.
        iters: Iteration count (10 is conservative; 7 usually suffices).

    Returns:
        (B, N, N) tensor approximating G^{-1/2}.
    """
    B, N, _ = G.shape
    dev, dt = G.device, G.dtype

    # Trace-normalize so the spectrum of G/tr lands inside the
    # Newton-Schulz convergence region (eigenvalues in (0, 3)).
    tr = G.diagonal(dim1=-2, dim2=-1).sum(-1, keepdim=True).unsqueeze(-1)
    tr = tr.clamp(min=1e-8)

    eye = torch.eye(N, device=dev, dtype=dt).unsqueeze(0).expand(B, -1, -1)
    Y = G / tr          # converges to (G/tr)^{1/2}
    Z = eye.clone()     # converges to (G/tr)^{-1/2}

    # Coupled iteration: both factors multiplied by (3I - ZY)/2 each step.
    for _ in range(iters):
        halfstep = 1.5 * eye - 0.5 * torch.bmm(Z, Y)
        Y = torch.bmm(Y, halfstep)
        Z = torch.bmm(halfstep, Z)

    # Undo the normalization: G^{-1/2} = (G/tr)^{-1/2} / sqrt(tr).
    return Z / tr.sqrt()
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# =============================================================================
# BATCHED PROCRUSTES ALIGNMENT
#   Subspace-preserving: rotate in k-d, leave orthogonal complement alone
# =============================================================================
|
| 353 |
+
|
| 354 |
+
def batched_procrustes(source, target, rank=24, whiten=True, schulz_iters=10):
    """Batched Procrustes alignment with rank-k subspace-preserving rotation.

    For N <= 32 (or rank >= N): runs full N-d Procrustes.
    For N > 32: projects onto a random orthonormal rank-k basis, solves
    Procrustes in k dimensions, and lifts back while leaving the
    orthogonal complement of the subspace exactly untouched.

    Args:
        source: (B, n_samples, N) or (n_samples, N) -- source embeddings.
        target: (B, n_samples, N) or (n_samples, N) -- target embeddings.
        rank: Projection rank for large N. Ignored if N <= 32.
        whiten: If True, apply Newton-Schulz whitening before rotation.
        schulz_iters: Iterations for the whitening (if enabled).

    Returns:
        aligned: same shape as source -- source aligned to target.
        info: dict with the rotation matrix (and projection basis for the
            subspace path) plus cosine-similarity diagnostics.
    """
    unbatched = source.ndim == 2
    if unbatched:
        source = source.unsqueeze(0)
        target = target.unsqueeze(0)

    B, n_samples, N = source.shape
    device = source.device
    source_f = source.float()
    target_f = target.float()

    # Center both point clouds.
    src_mean = source_f.mean(1, keepdim=True)
    tgt_mean = target_f.mean(1, keepdim=True)
    src_c = source_f - src_mean
    tgt_c = target_f - tgt_mean

    # Whiten if requested (Newton-Schulz inverse sqrt -- pure bmm).
    if whiten:
        src_cov = torch.bmm(src_c.transpose(1, 2), src_c) / max(n_samples - 1, 1)
        tgt_cov = torch.bmm(tgt_c.transpose(1, 2), tgt_c) / max(n_samples - 1, 1)
        src_W = newton_schulz_invsqrt(src_cov, iters=schulz_iters)  # (B, N, N)
        tgt_W = newton_schulz_invsqrt(tgt_cov, iters=schulz_iters)
        src_w = torch.bmm(src_c, src_W)
        tgt_w = torch.bmm(tgt_c, tgt_W)
        # Normalize rows so the rotation fit is scale-free.
        src_w = F.normalize(src_w, dim=-1)
        tgt_w = F.normalize(tgt_w, dim=-1)
    else:
        src_w = src_c
        tgt_w = tgt_c

    use_projection = N > 32 and rank < N

    if not use_projection:
        # --- Full N-d Procrustes ---
        C = torch.bmm(src_w.transpose(1, 2), tgt_w)  # (B, N, N) cross-covariance
        U, _, Vh = torch.linalg.svd(C)
        R = torch.bmm(U, Vh)  # (B, N, N) nearest orthogonal matrix to C

        aligned_w = torch.bmm(src_w, R)

        # Unwhiten back to target space.
        if whiten:
            tgt_unW = torch.linalg.pinv(tgt_W)  # (B, N, N)
            aligned = torch.bmm(aligned_w, tgt_unW) + tgt_mean
        else:
            aligned = aligned_w + tgt_mean

        # Diagnostic: mean cosine on (up to) the first 1000 samples.
        cos_after = F.cosine_similarity(
            aligned_w[:, :min(1000, n_samples)],
            tgt_w[:, :min(1000, n_samples)], dim=-1).mean().item()

        info = {
            'method': 'full',
            'N': N, 'rank': N,
            'rotation': R,
            'cos_after': cos_after,
        }

    else:
        # --- Subspace-preserving rank-k Procrustes ---
        k = min(rank, N - 1)

        # Random orthonormal projection basis via QR.
        P_raw = torch.randn(B, N, k, device=device, dtype=torch.float32)
        P = torch.linalg.qr(P_raw).Q  # (B, N, k) orthonormal columns

        # Project both clouds to k-d.
        src_proj = torch.bmm(src_w, P)  # (B, n_samples, k)
        tgt_proj = torch.bmm(tgt_w, P)  # (B, n_samples, k)

        # Procrustes in k-d (cheap -- kxk SVD).
        C_k = torch.bmm(src_proj.transpose(1, 2), tgt_proj)  # (B, k, k)
        U_k, _, Vh_k = torch.linalg.svd(C_k)
        R_k = torch.bmm(U_k, Vh_k)  # (B, k, k)

        # Subspace-preserving lift:
        #   1. split source into in-subspace + perpendicular components
        #   2. rotate only the in-subspace component
        #   3. add back the perpendicular component untouched
        # FIX: src_proj already holds the in-subspace coefficients; the
        # original recomputed the identical bmm(src_w, P) here.
        P_T = P.transpose(1, 2)                       # (B, k, N)
        src_in_fullspace = torch.bmm(src_proj, P_T)   # (B, n_samples, N)
        src_perp = src_w - src_in_fullspace           # orthogonal complement

        # Rotate the in-subspace component and lift back to N-d.
        src_rotated_k = torch.bmm(src_proj, R_k)              # (B, n_samples, k)
        src_rotated_fullspace = torch.bmm(src_rotated_k, P_T)  # (B, n_samples, N)

        # Recombine.
        aligned_w = src_rotated_fullspace + src_perp

        # Unwhiten.
        if whiten:
            tgt_unW = torch.linalg.pinv(tgt_W)
            aligned = torch.bmm(aligned_w, tgt_unW) + tgt_mean
        else:
            aligned = aligned_w + tgt_mean

        # Diagnostics in both the full space and the k-d subspace.
        cos_after_full = F.cosine_similarity(
            aligned_w[:, :min(1000, n_samples)],
            tgt_w[:, :min(1000, n_samples)], dim=-1).mean().item()
        cos_after_k = F.cosine_similarity(
            src_rotated_k[:, :min(1000, n_samples)],
            tgt_proj[:, :min(1000, n_samples)], dim=-1).mean().item()

        info = {
            'method': 'subspace',
            'N': N, 'rank': k,
            'rotation_k': R_k,
            'projection': P,
            'cos_after': cos_after_full,
            'cos_after_k': cos_after_k,
        }

    if unbatched:
        aligned = aligned.squeeze(0)

    return aligned, info
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def batched_procrustes_align_pair(source, target, rank=24, whiten=True,
                                  schulz_iters=10, n_align=10000):
    """Convenience wrapper: fit alignment on a subset, apply to all samples.

    Computes the Procrustes alignment on the first ``n_align`` samples (via
    ``batched_procrustes``) and then applies the fitted transform —
    whitening, rotation, unwhitening, recentering — to the full source.

    Args:
        source: (n_samples, N) source embeddings.
        target: (n_samples, N) target embeddings.
        rank: Projection rank for N > 32 (subspace method).
        whiten: Apply Newton-Schulz whitening before rotating.
        schulz_iters: Newton-Schulz iterations passed to batched_procrustes.
        n_align: Number of leading samples used to fit the alignment.

    Returns:
        aligned: (n_samples, N) aligned source.
        info: alignment diagnostics from batched_procrustes.
    """
    n = min(n_align, source.shape[0], target.shape[0])

    # Fit the alignment on the leading subset only.
    _, info = batched_procrustes(
        source[:n].unsqueeze(0), target[:n].unsqueeze(0),
        rank=rank, whiten=whiten, schulz_iters=schulz_iters)

    # Subset statistics, applied to the full source.
    src_mean = source[:n].float().mean(0, keepdim=True)
    tgt_mean = target[:n].float().mean(0, keepdim=True)
    src_c = source.float() - src_mean

    # Whitening is identical for both rotation methods; compute it once
    # (it was previously duplicated verbatim in each branch).
    tgt_unW = None
    if whiten:
        src_sub = source[:n].float() - src_mean
        tgt_sub = target[:n].float() - tgt_mean
        denom = max(n - 1, 1)
        src_cov = src_sub.T @ src_sub / denom
        tgt_cov = tgt_sub.T @ tgt_sub / denom
        # NOTE(review): schulz_iters is forwarded to batched_procrustes but
        # not here — confirm newton_schulz_invsqrt's default iteration count
        # matches, otherwise the applied whitening differs from the fit.
        src_W = newton_schulz_invsqrt(src_cov.unsqueeze(0)).squeeze(0)
        tgt_W = newton_schulz_invsqrt(tgt_cov.unsqueeze(0)).squeeze(0)
        tgt_unW = torch.linalg.pinv(tgt_W)
        src_w = F.normalize(src_c @ src_W, dim=-1)
    else:
        src_w = src_c

    if info['method'] == 'full':
        # Full N-d rotation fitted on the subset.
        R = info['rotation'].squeeze(0)  # (N, N)
        rotated = src_w @ R
    else:
        # Subspace method: rotate only the in-subspace component and keep
        # the orthogonal complement untouched.
        P = info['projection'].squeeze(0)    # (N, k)
        R_k = info['rotation_k'].squeeze(0)  # (k, k)
        src_in = src_w @ P                   # (n_all, k)
        src_perp = src_w - src_in @ P.T
        rotated = src_in @ R_k @ P.T + src_perp

    # Unwhiten (if applicable) and recenter onto the target mean.
    if whiten:
        aligned = rotated @ tgt_unW + tgt_mean
    else:
        aligned = rotated + tgt_mean

    return aligned, info
|
| 562 |
+
|
| 563 |
+
def projected_svd(A, target_rank=24, oversampling=8):
    """Randomized rank-projected thin SVD for batched (B, M, N) inputs.

    A simplified Halko-Martinsson-Tropp (2011) sketch: project columns
    from N down to k = target_rank + oversampling dimensions with a
    Gaussian matrix, factor in the cheap k-d space via gram_eigh_svd,
    lift the right singular vectors back to N-d, and recover U from A.
    Singular structure beyond rank k is lost (set to zero).

    Args:
        A: (B, M, N) input tensor.
        target_rank: Number of singular values/vectors to recover.
        oversampling: Extra sketch dimensions for numerical stability.

    Returns:
        U: (B, M, target_rank) thin left singular vectors.
        S: (B, target_rank) top-k singular values, descending.
        Vh: (B, target_rank, N) right singular vectors in N-d space.
    """
    batch, _, width = A.shape
    work = A.float()
    sketch_dim = min(target_rank + oversampling, width)

    # Sketch would not shrink anything — use the exact path, trimmed
    # down to target_rank.
    if sketch_dim >= width:
        U_e, S_e, Vh_e = gram_eigh_svd(A)
        keep = min(target_rank, width)
        return U_e[:, :, :keep], S_e[:, :keep], Vh_e[:, :keep, :]

    # Gaussian sketch matrix (N, k), scaled like a JL projection.
    sketch = torch.randn(width, sketch_dim, device=A.device,
                         dtype=torch.float32) / math.sqrt(sketch_dim)

    # Project the whole batch into k-d; this is a fast bmm.
    sketched = torch.bmm(work, sketch.unsqueeze(0).expand(batch, -1, -1))

    # Factor in the reduced space (k x k Gram matrix — cheap).
    _, S_small, Vh_small = gram_eigh_svd(sketched)  # Vh_small: (B, k, k)

    # Lift the right singular vectors back to the original N-d space.
    lift = sketch.T.unsqueeze(0).expand(batch, -1, -1)  # (B, k, N)
    Vh_lifted = torch.bmm(Vh_small, lift)               # (B, k, N)

    # Re-orthonormalize the lifted rows (the sketch perturbs orthogonality).
    Vh_lifted = torch.linalg.qr(Vh_lifted.transpose(-2, -1)).Q.transpose(-2, -1)

    # Recover the left factor via U = A V / S (clamped for tiny values).
    V_lifted = Vh_lifted.transpose(-2, -1)  # (B, N, k)
    U_lifted = torch.bmm(work, V_lifted) / S_small.unsqueeze(1).clamp(min=1e-12)

    # Drop the oversampling columns before returning.
    return (U_lifted[:, :, :target_rank],
            S_small[:, :target_rank],
            Vh_lifted[:, :target_rank, :])
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def projected_svd_quality(A, target_rank=24):
    """Quantify how close the rank-projected SVD gets to an exact SVD.

    Runs torch.linalg.svd as the reference and compares energy capture,
    reconstruction RMSE (vs the full and Eckart-Young rank-k optima),
    singular-value agreement, and right-subspace overlap.

    Returns:
        dict with energy_ratio, recon_proj/full/trunc, s_err, s_rel_err,
        and subspace_cos.
    """
    ref = A.float()

    # Full reference factorization.
    U_ref, S_ref, Vh_ref = torch.linalg.svd(ref, full_matrices=False)
    S_top = S_ref[:, :target_rank]

    # Fraction of squared singular-value mass captured by the top-k values.
    energy_ratio = (S_top.pow(2).sum(dim=-1)
                    / S_ref.pow(2).sum(dim=-1).clamp(min=1e-12)).mean().item()

    # Sketched factorization under test.
    U_p, S_p, Vh_p = projected_svd(A, target_rank=target_rank)

    def _rmse(approx):
        # Root-mean-square reconstruction error against the input.
        return (ref - approx).pow(2).mean().sqrt().item()

    # Sketched reconstruction error.
    recon_err = _rmse(torch.bmm(U_p * S_p.unsqueeze(1), Vh_p))
    # Full-rank reference floor (numerical noise only).
    recon_ref = _rmse(torch.bmm(U_ref * S_ref.unsqueeze(1), Vh_ref))
    # Eckart-Young optimum: best possible rank-k reconstruction.
    recon_trunc_err = _rmse(torch.bmm(
        U_ref[:, :, :target_rank] * S_top.unsqueeze(1),
        Vh_ref[:, :target_rank, :]))

    # Top-k singular value agreement (absolute and relative).
    s_err = (S_p - S_top).abs().mean().item()
    s_scale = S_top.abs().mean().item()
    s_rel_err = s_err / s_scale if s_scale > 1e-8 else 0.0

    # Subspace agreement: singular values of V_p^T V_ref are the cosines
    # of the principal angles between the two right subspaces.
    overlap = torch.bmm(Vh_p, Vh_ref[:, :target_rank, :].transpose(-2, -1))
    subspace_cos = torch.linalg.svdvals(overlap).mean().item()

    return {
        'energy_ratio': energy_ratio,
        'recon_proj': recon_err,
        'recon_full': recon_ref,
        'recon_trunc': recon_trunc_err,
        's_err': s_err,
        's_rel_err': s_rel_err,
        'subspace_cos': subspace_cos,
    }
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def procrustes_alignment_quality(N=48, k=24, n_samples=5000, device='cuda'):
    """Compare 5 methods of applying rank-k Procrustes back to N-d.

    Synthesizes a correlated source/target pair sharing a low-rank basis,
    then measures mean cosine similarity to the target (and NN agreement
    with the full N-d solution) for each lift-back strategy.

    Methods:
      1. full: Full N-d Procrustes (ceiling)
      2. pinv: P @ R_k @ pinv(P) — naive lift (broken baseline)
      3. lerp: (1-a)I + a*(P @ R_k @ pinv(P)) — blend with identity
      4. slerp: matrix_exp(a * matrix_log(R_lifted)) — geodesic on SO(N)
      5. subspace: Rotate in-subspace component, preserve orthogonal complement
      6. stay_k: Don't lift — compare in k-d (reference for k-d quality)

    Args:
        N: Ambient embedding dimension.
        k: Projection rank for the reduced Procrustes problem.
        n_samples: Number of synthetic samples to generate.
        device: Torch device for the experiment (default 'cuda', which was
            previously hard-coded; pass 'cpu' to run without a GPU).

    Returns:
        dict of per-method cosine similarities, best lerp/slerp alphas,
        NN-agreement scores, and the full lerp alpha sweep.
    """
    # Create two embedding spaces with shared low-rank structure + noise.
    shared_rank = min(N // 2, 32)
    shared_basis = torch.randn(shared_rank, N, device=device)
    shared_basis = torch.linalg.qr(shared_basis.T).Q.T  # orthonormal rows

    coeffs_src = torch.randn(n_samples, shared_rank, device=device)
    coeffs_tgt = torch.randn(n_samples, shared_rank, device=device) * 0.8 + coeffs_src * 0.5
    noise_scale = 0.3

    source = coeffs_src @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
    target = coeffs_tgt @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)

    # Center both clouds (Procrustes assumes zero-mean inputs).
    source = source - source.mean(0, keepdim=True)
    target = target - target.mean(0, keepdim=True)

    # --- Full N-d Procrustes (ceiling) ---
    C_full = source.T @ target
    U_f, _, Vh_f = torch.linalg.svd(C_full)
    R_full = U_f @ Vh_f
    aligned_full = source @ R_full
    cos_full = F.cosine_similarity(aligned_full, target, dim=-1).mean().item()

    # --- Projected k-d Procrustes ---
    P = torch.randn(N, k, device=device) / math.sqrt(k)
    # Orthogonalize P for a clean subspace decomposition.
    P = torch.linalg.qr(P).Q  # (N, k) orthonormal columns

    src_proj = source @ P
    tgt_proj = target @ P

    C_proj = src_proj.T @ tgt_proj
    U_p, _, Vh_p = torch.linalg.svd(C_proj)
    R_k = U_p @ Vh_p  # (k, k) optimal rotation in k-d

    # --- Method 1: Naive pinv lift (broken baseline) ---
    P_pinv = torch.linalg.pinv(P)
    R_pinv = P @ R_k @ P_pinv
    aligned_pinv = source @ R_pinv
    cos_pinv = F.cosine_similarity(aligned_pinv, target, dim=-1).mean().item()

    # --- Method 2: LERP — blend projected rotation with identity ---
    # Test multiple alpha values, keep the best.
    I_N = torch.eye(N, device=device)
    best_lerp_cos = -1.0
    best_lerp_alpha = 0.0
    lerp_results = {}
    for alpha in [0.3, 0.5, 0.7, 0.9, 1.0]:
        R_lerp = (1.0 - alpha) * I_N + alpha * R_pinv
        aligned_lerp = source @ R_lerp
        c = F.cosine_similarity(aligned_lerp, target, dim=-1).mean().item()
        lerp_results[alpha] = c
        if c > best_lerp_cos:
            best_lerp_cos = c
            best_lerp_alpha = alpha
    # Recompute the winning blend to measure NN agreement later.
    R_lerp_best = (1.0 - best_lerp_alpha) * I_N + best_lerp_alpha * R_pinv
    aligned_lerp_best = source @ R_lerp_best

    # --- Method 3: SLERP — geodesic interpolation on rotation manifold ---
    # R_pinv may not be exactly orthogonal, so snap it to the closest
    # orthogonal matrix first.
    U_clean, _, Vh_clean = torch.linalg.svd(R_pinv)
    R_ortho = U_clean @ Vh_clean  # closest orthogonal matrix

    best_slerp_cos = -1.0
    best_slerp_alpha = 0.0
    try:
        # NOTE(review): torch.linalg.matrix_log is not available in all
        # torch builds; any failure here (including AttributeError) drops
        # to the pinv fallback below with slerp_alpha reported as -1.
        log_R = torch.linalg.matrix_log(R_ortho.to(torch.complex64)).real
        slerp_works = True
    except Exception:
        slerp_works = False
        log_R = None

    if slerp_works:
        for alpha in [0.3, 0.5, 0.7, 0.9, 1.0]:
            R_slerp = torch.matrix_exp(alpha * log_R)
            aligned_slerp = source @ R_slerp
            c = F.cosine_similarity(aligned_slerp, target, dim=-1).mean().item()
            if c > best_slerp_cos:
                best_slerp_cos = c
                best_slerp_alpha = alpha
        R_slerp_best = torch.matrix_exp(best_slerp_alpha * log_R)
        aligned_slerp_best = source @ R_slerp_best
    else:
        best_slerp_cos = cos_pinv
        best_slerp_alpha = -1.0
        aligned_slerp_best = aligned_pinv

    # --- Method 4: Subspace-preserving rotation ---
    # P @ P^T projects onto the k-d subspace (P has orthonormal columns).
    src_in = source @ P                # (n, k) coefficients in subspace
    src_perp = source - src_in @ P.T   # (n, N) orthogonal complement

    # Rotate only the in-subspace component; add the complement back.
    src_in_rotated = src_in @ R_k      # (n, k)
    aligned_subspace = src_in_rotated @ P.T + src_perp
    cos_subspace = F.cosine_similarity(aligned_subspace, target, dim=-1).mean().item()

    # --- Method 5: Stay in k-d (don't lift, reference) ---
    aligned_k = src_proj @ R_k
    cos_stay_k = F.cosine_similarity(aligned_k, tgt_proj, dim=-1).mean().item()

    # --- NN agreement for all methods ---
    n_anchor = min(100, n_samples // 2)

    def _nn_agree(aligned_a, aligned_b):
        # Fraction of query points whose nearest anchor agrees between
        # the two alignments (proxy for identical downstream retrieval).
        anc_a, anc_b = aligned_a[:n_anchor], aligned_b[:n_anchor]
        q_a, q_b = aligned_a[n_anchor:], aligned_b[n_anchor:]
        nn_a = (q_a @ anc_a.T).argmax(-1)
        nn_b = (q_b @ anc_b.T).argmax(-1)
        return (nn_a == nn_b).float().mean().item()

    nn_pinv = _nn_agree(aligned_full, aligned_pinv)
    nn_lerp = _nn_agree(aligned_full, aligned_lerp_best)
    nn_slerp = _nn_agree(aligned_full, aligned_slerp_best)
    nn_subspace = _nn_agree(aligned_full, aligned_subspace)

    return {
        'N': N, 'k': k,
        'cos_full': cos_full,
        'cos_pinv': cos_pinv,
        'cos_lerp': best_lerp_cos, 'lerp_alpha': best_lerp_alpha,
        'cos_slerp': best_slerp_cos, 'slerp_alpha': best_slerp_alpha,
        'cos_subspace': cos_subspace,
        'cos_stay_k': cos_stay_k,
        'nn_pinv': nn_pinv, 'nn_lerp': nn_lerp,
        'nn_slerp': nn_slerp, 'nn_subspace': nn_subspace,
        'lerp_all': lerp_results,
    }
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
def profile_procrustes_quality():
    """Compare all Procrustes lift-back methods.

    Runs procrustes_alignment_quality over a grid of (N, rank) configs,
    prints a per-config table of cosine similarities and NN agreement,
    then a winner summary per config.

    Returns:
        list[dict]: one diagnostics dict per (N, k) combination, as
        returned by procrustes_alignment_quality.
    """
    print(f"\n{'='*120}")
    print(f" PROCRUSTES ALIGNMENT: 5 methods of applying rank-k rotation to N-d space")
    print(f" cos = mean cosine similarity after alignment (higher = better, full = ceiling)")
    print(f" NN = nearest-neighbor agreement with full Procrustes (1.0 = identical downstream)")
    print(f"{'='*120}")

    # (ambient dim N, list of projection ranks k to test); k >= N is
    # skipped in the inner loop below.
    configs = [
        (32, [8, 16, 24]),
        (48, [8, 16, 24, 32]),
        (64, [8, 16, 24, 32]),
        (96, [16, 24, 32, 48]),
        (128, [16, 24, 32, 48, 64]),
    ]

    all_results = []

    for N, ranks in configs:
        print(f"\n N={N}:")
        # Table header: cosine columns per method, then NN-agreement columns.
        print(f" {'k':>5} {'full':>7} {'pinv':>7} {'lerp':>7} {'(Ξ±)':>4}"
              f" {'slerp':>7} {'(Ξ±)':>4} {'subspc':>7} {'stay_k':>7}"
              f" β {'nn_pv':>6} {'nn_lr':>6} {'nn_sl':>6} {'nn_ss':>6}")
        print(f" {'β'*105}")

        for k in ranks:
            if k >= N:
                continue  # rank must be strictly below the ambient dimension
            q = procrustes_alignment_quality(N=N, k=k)

            # slerp_alpha < 0 signals matrix_log failed and slerp fell back.
            sl_alpha = f"{q['slerp_alpha']:.1f}" if q['slerp_alpha'] >= 0 else " err"

            print(f" {k:>5} {q['cos_full']:>7.4f} {q['cos_pinv']:>7.4f}"
                  f" {q['cos_lerp']:>7.4f} {q['lerp_alpha']:>3.1f}"
                  f" {q['cos_slerp']:>7.4f} {sl_alpha:>4}"
                  f" {q['cos_subspace']:>7.4f} {q['cos_stay_k']:>7.4f}"
                  f" β {q['nn_pinv']:>6.3f} {q['nn_lerp']:>6.3f}"
                  f" {q['nn_slerp']:>6.3f} {q['nn_subspace']:>6.3f}")
            all_results.append(q)

    # Winner summary
    print(f"\n {'β'*105}")
    print(f" WINNER PER CONFIG (closest cos to full, highest NN agreement):")
    print(f" {'β'*105}")
    for q in all_results:
        # Best lift-back method by cosine (full is the ceiling, excluded).
        methods = {
            'pinv': q['cos_pinv'], 'lerp': q['cos_lerp'],
            'slerp': q['cos_slerp'], 'subspace': q['cos_subspace'],
        }
        best_method = max(methods, key=methods.get)
        best_cos = methods[best_method]
        gap = q['cos_full'] - best_cos
        # Best method by NN agreement with the full solution.
        nn_methods = {
            'pinv': q['nn_pinv'], 'lerp': q['nn_lerp'],
            'slerp': q['nn_slerp'], 'subspace': q['nn_subspace'],
        }
        best_nn_method = max(nn_methods, key=nn_methods.get)
        print(f" N={q['N']:>3} k={q['k']:>3}: best_cos={best_method:>8} ({best_cos:.4f}, gap={gap:.4f})"
              f" best_nn={best_nn_method:>8} ({nn_methods[best_nn_method]:.3f})")

    return all_results
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
def batched_svd(A, method='auto', block_m=128, newton=False, target_rank=None):
    """Batched thin SVD dispatcher for (B, M, N) CUDA tensors with M >= N.

    Args:
        A: (B, M, N) CUDA tensor.
        method: 'auto', 'triton', 'gram_eigh', 'newton', 'projected', 'torch'.
        block_m: Tile size for the fused Triton kernels (N=2,3 only).
        newton: In 'auto' mode, route N>=48 through newton_svd.
        target_rank: Rank for 'projected'; in 'auto' mode, setting it routes
            N>=48 through the fast approximate projected SVD. Default None
            keeps the exact (slower) gram_eigh path.

    Auto dispatch:
        N=2 / N=3            -> fused Triton kernels
        N>=48, target_rank   -> projected SVD (approximate)
        N>=48, newton=True   -> Newton SVD
        otherwise            -> Gram + eigh (exact)

    Returns:
        (U, S, Vh) with singular values descending. Full methods give
        U(B,M,N), S(B,N), Vh(B,N,N); projected gives k=target_rank shapes
        U(B,M,k), S(B,k), Vh(B,k,N).
    """
    assert A.ndim == 3, f"Expected (B, M, N), got shape {A.shape}"
    assert A.is_cuda, "Input must be on CUDA"
    B, M, N = A.shape
    assert M >= N, f"Thin SVD requires M >= N, got M={M}, N={N}"

    if method == 'triton':
        # Fused kernels exist only for the tiny closed-form cases.
        if N == 2:
            return batched_svd2(A, block_m)
        if N == 3:
            return batched_svd3(A, block_m)
        raise ValueError(f"Fused Triton kernel only available for N=2,3, got N={N}")

    if method == 'gram_eigh':
        return gram_eigh_svd(A)

    if method == 'newton':
        return newton_svd(A)

    if method == 'projected':
        # Fall back to a reasonable rank when none was supplied.
        rank = target_rank or min(N // 2, 32)
        return projected_svd(A, target_rank=rank)

    if method == 'torch':
        return torch.linalg.svd(A.float(), full_matrices=False)

    if method == 'auto':
        if N == 2:
            return batched_svd2(A, block_m)
        if N == 3:
            return batched_svd3(A, block_m)
        if N >= 48 and target_rank is not None:
            return projected_svd(A, target_rank=target_rank)
        if N >= 48 and newton:
            return newton_svd(A)
        return gram_eigh_svd(A)

    raise ValueError(f"Unknown method '{method}'. Use: auto, triton, gram_eigh, newton, projected, torch")
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 966 |
+
# β CORRECTNESS VALIDATION β
|
| 967 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 968 |
+
|
| 969 |
+
def validate_svd(A, U, S, Vh, label=""):
    """Check an SVD of A: reconstruction, U orthonormality, S sanity.

    Compares the candidate factorization against torch.linalg.svd as a
    reference floor, prints a one-line report, and returns True iff all
    checks pass.
    """
    batch, _, width = A.shape
    exact = A.float()

    # Reference factorization and its own reconstruction floor.
    U_ref, S_ref, Vh_ref = torch.linalg.svd(exact, full_matrices=False)
    s_err = (S - S_ref).abs().max().item()
    recon_ref = (exact - torch.bmm(U_ref * S_ref.unsqueeze(1), Vh_ref)).abs().max().item()

    # How well does the candidate reproduce A? (A =? U diag(S) Vh)
    recon_err = (exact - torch.bmm(U * S.unsqueeze(1), Vh)).abs().max().item()

    # U should have orthonormal columns: U^T U =? I.
    gram = torch.bmm(U.transpose(1, 2), U)
    identity = torch.eye(width, device=A.device).expand(batch, -1, -1)
    orth_err = (gram - identity).abs().max().item()

    # Singular values must be non-negative and (approximately) descending.
    s_min = S.min().item()
    s_sorted = (S[:, :-1] >= S[:, 1:] - 1e-6).all().item()

    # Pass criteria: reconstruction within 3x of the reference floor
    # (or an absolute 1e-3), near-orthonormal U, non-negative S.
    passed = recon_err < max(recon_ref * 3, 1e-3) and orth_err < 1e-2 and s_min >= -1e-6
    status = "PASS" if passed else "FAIL"
    tag = f"[{label}] " if label else ""

    print(f" {tag}N={width:>3}: S_err={s_err:.2e} recon={recon_err:.2e} (ref={recon_ref:.2e})"
          f" orth={orth_err:.2e} desc={s_sorted} [{status}]")
    return passed
|
| 999 |
+
|
| 1000 |
+
|
| 1001 |
+
def run_validation(B=64, M=1024):
    """Validate all SVD methods across N, then the Procrustes alignment APIs.

    Requires CUDA: test tensors are allocated on 'cuda' and batched_svd
    asserts CUDA input.

    Args:
        B: batch size of the random test tensors.
        M: rows per matrix (N values larger than M are skipped).

    Returns:
        bool: True iff every SVD validation passed. (The Procrustes checks
        print diagnostics but do not affect the return value.)
    """
    print(f"\n{'='*70}")
    print(f" CORRECTNESS VALIDATION (B={B}, M={M})")
    print(f"{'='*70}")

    all_pass = True

    for N in [2, 3, 4, 5, 6, 8, 10, 16, 32, 48, 64, 96, 128]:
        if N > M:
            continue  # thin SVD requires M >= N
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)

        # Auto method (dispatches per N; see batched_svd)
        U, S, Vh = batched_svd(A, method='auto')
        p = validate_svd(A, U, S, Vh, label="auto")
        all_pass = all_pass and p

        # Explicit Triton kernel validation (N=2,3)
        if N <= 3:
            Ut, St, Vht = batched_svd(A, method='triton')
            pt = validate_svd(A, Ut, St, Vht, label="triton")
            all_pass = all_pass and pt

        # Gram-eigh for comparison (if N > 3)
        if N > 3:
            U2, S2, Vh2 = batched_svd(A, method='gram_eigh')
            p2 = validate_svd(A, U2, S2, Vh2, label="gram")
            all_pass = all_pass and p2

        # Newton for comparison (if N >= 8)
        if N >= 8:
            U3, S3, Vh3 = newton_svd(A)
            p3 = validate_svd(A, U3, S3, Vh3, label="newton")
            all_pass = all_pass and p3

    print(f"\n {'ALL PASSED' if all_pass else 'SOME FAILURES'}")

    # -- Procrustes alignment validation --
    print(f"\n{'='*70}")
    print(f" PROCRUSTES ALIGNMENT VALIDATION")
    print(f"{'='*70}")

    for N in [16, 32, 48, 64, 128]:
        n_samp = 2000
        # Create correlated source/target (shared signal + independent noise)
        shared = torch.randn(n_samp, N, device='cuda')
        source = shared + 0.3 * torch.randn(n_samp, N, device='cuda')
        target = shared + 0.3 * torch.randn(n_samp, N, device='cuda')

        rank = min(24, N - 1)
        aligned, info = batched_procrustes(
            source.unsqueeze(0), target.unsqueeze(0),
            rank=rank, whiten=True)
        aligned = aligned.squeeze(0)

        # Alignment should raise the mean cosine similarity to the target.
        cos_before = F.cosine_similarity(source, target, dim=-1).mean().item()
        cos_after = F.cosine_similarity(aligned, target, dim=-1).mean().item()
        improved = cos_after > cos_before

        print(f" N={N:>3} rank={rank:>3} method={info['method']:>8}:"
              f" cos {cos_before:.4f} β {cos_after:.4f}"
              f" {'IMPROVED' if improved else 'WORSE'}")

    # Test unbatched interface
    # NOTE: `assert` is used for validation below — stripped under -O.
    source_ub = torch.randn(1000, 48, device='cuda')
    target_ub = torch.randn(1000, 48, device='cuda') * 0.5 + source_ub * 0.5
    aligned_ub, info_ub = batched_procrustes(source_ub, target_ub, rank=24)
    assert aligned_ub.shape == source_ub.shape, f"Shape mismatch: {aligned_ub.shape} vs {source_ub.shape}"
    print(f" Unbatched API: shape {aligned_ub.shape} β method={info_ub['method']}")

    # Test batched_procrustes_align_pair (fit on subset, apply to all)
    aligned_pair, info_pair = batched_procrustes_align_pair(
        source_ub, target_ub, rank=24, n_align=500)
    assert aligned_pair.shape == source_ub.shape
    cos_pair = F.cosine_similarity(aligned_pair, target_ub, dim=-1).mean().item()
    print(f" Align-pair API: cos={cos_pair:.4f} method={info_pair['method']}")

    print(f" PROCRUSTES VALIDATION COMPLETE")

    return all_pass
|
| 1082 |
+
|
| 1083 |
+
|
| 1084 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1085 |
+
# β BENCHMARKING β
|
| 1086 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1087 |
+
|
| 1088 |
+
def _cuda_timer(fn, warmup=20, iters=80):
    """CUDA-event-timed benchmark. Returns (mean_ms, std_ms, median_ms)."""
    # Warm up (autotune/cache effects), then drain the stream.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    # One (start, end) event pair per timed iteration.
    pairs = [(torch.cuda.Event(enable_timing=True),
              torch.cuda.Event(enable_timing=True)) for _ in range(iters)]
    for begin, finish in pairs:
        begin.record()
        fn()
        finish.record()
    torch.cuda.synchronize()

    samples = torch.tensor([begin.elapsed_time(finish) for begin, finish in pairs])
    return samples.mean().item(), samples.std().item(), samples.median().item()
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
def profile_n_sweep(B=512, M=1024):
    """Sweep N from 2 to 128 and time every SVD method against torch.

    For each N, benchmarks the applicable methods (Triton for N<=3,
    Newton for N>=8, projected for N>=24/32), prints a comparison table,
    and returns per-N result rows. Requires CUDA.

    Args:
        B: batch size.
        M: rows per matrix (N values larger than M are skipped).

    Returns:
        list[dict]: timings (ms), the fastest method, and its speedup
        vs torch.linalg.svd for each N.
    """
    device_name = torch.cuda.get_device_name(0)
    print(f"\n{'='*110}")
    print(f" N-DIMENSION SWEEP β {device_name}")
    print(f" B={B}, M={M}")
    print(f"{'='*110}")
    print(f" {'N':>4} {'Triton':>10} {'Gram':>10} {'Newton':>10}"
          f" {'Projβ24':>10} {'Projβ16':>10} {'Torch':>10} {'Best':>8} {'Speedup':>8}")
    print(f" {'β'*106}")

    results = []
    n_values = [2, 3, 4, 5, 6, 7, 8, 10, 12, 16, 20, 24, 32, 48, 64, 96, 128]

    def _fmt(ms):
        # NaN marks "method not applicable at this N"; NaN != NaN.
        if ms != ms:  # nan
            return f"{'β':>10}"
        return f"{ms:>8.3f}ms"

    for N in n_values:
        if N > M:
            continue
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)

        # Triton fused kernels exist only for N=2,3.
        triton_ms = float('nan')
        if N <= 3:
            triton_ms, _, _ = _cuda_timer(lambda: batched_svd(A, method='triton'))

        torch_ms, _, _ = _cuda_timer(lambda: torch.linalg.svd(A, full_matrices=False))
        gram_ms, _, _ = _cuda_timer(lambda: gram_eigh_svd(A))

        newton_ms = float('nan')
        if N >= 8:
            newton_ms, _, _ = _cuda_timer(lambda: newton_svd(A))

        proj24_ms = float('nan')
        if N >= 32:
            proj24_ms, _, _ = _cuda_timer(lambda: projected_svd(A, target_rank=min(24, N-1)))

        proj16_ms = float('nan')
        if N >= 24:
            proj16_ms, _, _ = _cuda_timer(lambda: projected_svd(A, target_rank=min(16, N-1)))

        # Determine best
        # (guards mirror the timing conditions above, so no NaN enters `times`)
        times = {'torch': torch_ms, 'gram': gram_ms}
        if N <= 3: times['triton'] = triton_ms
        if N >= 8: times['newton'] = newton_ms
        if N >= 32: times['proj24'] = proj24_ms
        if N >= 24: times['proj16'] = proj16_ms
        best = min(times, key=times.get)
        speedup = torch_ms / (times[best] + 1e-9)

        print(f" {N:>4} {_fmt(triton_ms)} {_fmt(gram_ms)} {_fmt(newton_ms)}"
              f" {_fmt(proj24_ms)} {_fmt(proj16_ms)} {_fmt(torch_ms)}"
              f" {best:>8} {speedup:>7.1f}x")

        row = {'N': N, 'B': B, 'M': M, 'torch_ms': round(torch_ms, 4),
               'gram_ms': round(gram_ms, 4), 'best': best,
               'speedup_vs_torch': round(speedup, 3)}
        # Record only timings that were actually measured (v == v filters NaN).
        for k, v in [('triton_ms', triton_ms), ('newton_ms', newton_ms),
                     ('proj24_ms', proj24_ms), ('proj16_ms', proj16_ms)]:
            if v == v: row[k] = round(v, 4)
        results.append(row)

        # Free the test tensor before the next (possibly larger) allocation.
        del A; torch.cuda.empty_cache()

    return results
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
def profile_projection_quality(B=256, M=1024):
    """Measure projection quality: how much information does rank-k SVD preserve?

    For each N, tests multiple target_rank values. Reports:
    - Energy ratio: fraction of total singular value energy in top-k
    - Reconstruction error: projected vs full SVD
    - Subspace agreement: cosine of principal angles between subspaces
    - Timing: projected vs full SVD
    """
    # NOTE(review): the 'β'/'β₯' glyphs in the banners below look like mojibake
    # (likely box-drawing '─' / '≥' in the original encoding) — left untouched
    # because they are runtime output; confirm the file's source encoding.
    print(f"\n{'='*100}")
    print(f" PROJECTION QUALITY ANALYSIS β B={B}, M={M}")
    print(f" Question: can rank-k SVD approximate rank-N SVD?")
    print(f"{'='*100}")

    configs = [
        # (N, [target_ranks to test])
        (32, [8, 12, 16, 24]),
        (48, [8, 12, 16, 24, 32]),
        (64, [8, 12, 16, 24, 32, 48]),
        (96, [8, 16, 24, 32, 48, 64]),
        (128, [8, 16, 24, 32, 48, 64, 96]),
    ]

    all_results = []

    for N, ranks in configs:
        # Thin SVD only makes sense for tall matrices (M >= N).
        if N > M:
            continue

        print(f"\n N={N}:")
        print(f" {'k':>5} {'Energy%':>8} {'Recon_proj':>11} {'Recon_trunc':>12}"
              f" {'S_rel_err':>10} {'Subspace':>9} {'Proj ms':>10} {'Full ms':>10} {'Speedup':>8}")
        print(f" {'β'*96}")

        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)

        # Time full SVD once
        full_ms, _, _ = _cuda_timer(lambda: gram_eigh_svd(A), warmup=10, iters=40)

        for k in ranks:
            # A projection rank is only meaningful as a strict reduction of N.
            if k >= N:
                continue

            # q's keys (energy_ratio, recon_proj, recon_trunc, s_rel_err,
            # subspace_cos) come from projected_svd_quality, defined elsewhere
            # in this file.
            q = projected_svd_quality(A, target_rank=k)
            proj_ms, _, _ = _cuda_timer(
                lambda: projected_svd(A, target_rank=k), warmup=10, iters=40)

            # +1e-9 guards against division by a zero timing.
            speedup = full_ms / (proj_ms + 1e-9)

            print(f" {k:>5} {q['energy_ratio']*100:>7.2f}% {q['recon_proj']:>11.2e}"
                  f" {q['recon_trunc']:>12.2e} {q['s_rel_err']:>10.4f}"
                  f" {q['subspace_cos']:>9.4f} {proj_ms:>8.3f}ms {full_ms:>8.3f}ms"
                  f" {speedup:>7.1f}x")

            all_results.append({
                'N': N, 'k': k, 'B': B, 'M': M,
                'energy_ratio': round(q['energy_ratio'], 6),
                'recon_proj': round(q['recon_proj'], 8),
                'recon_trunc': round(q['recon_trunc'], 8),
                's_rel_err': round(q['s_rel_err'], 6),
                'subspace_cos': round(q['subspace_cos'], 6),
                'proj_ms': round(proj_ms, 4),
                'full_ms': round(full_ms, 4),
            })

        # Free the test tensor before moving to the next (larger-N) config.
        del A; torch.cuda.empty_cache()

    # Summary table
    print(f"\n {'β'*70}")
    print(f" SUMMARY: Recommended target_rank per N")
    print(f" (β₯99% energy, β₯0.99 subspace cos, best speedup)")
    print(f" {'β'*70}")
    for N, ranks in configs:
        # Recommend the smallest k clearing both quality thresholds.
        good = [r for r in all_results if r['N'] == N
                and r['energy_ratio'] >= 0.99 and r['subspace_cos'] >= 0.99]
        if good:
            best = min(good, key=lambda r: r['k'])
            print(f" N={N:>3}: k={best['k']:>3} β {best['energy_ratio']*100:.1f}% energy,"
                  f" subspace={best['subspace_cos']:.4f},"
                  f" {best['full_ms']/best['proj_ms']:.1f}x speedup")
        else:
            # Find best available
            available = [r for r in all_results if r['N'] == N]
            if available:
                best = max(available, key=lambda r: r['energy_ratio'])
                print(f" N={N:>3}: best k={best['k']:>3} β {best['energy_ratio']*100:.1f}% energy,"
                      f" subspace={best['subspace_cos']:.4f} (below 99% threshold)")

    return all_results
def profile_batch_sweep(N=3, M=1024):
    """Sweep batch size for a fixed N. Shows scaling behavior.

    Doubles the batch from 32 up to 8192 (or until allocation fails),
    timing the auto-dispatched batched SVD against torch.linalg.svd.

    Args:
        N: column count of each batched matrix.
        M: row (spatial) dimension of each batched matrix.

    Returns:
        list of dicts, one per batch size that fit in memory, with keys
        'B', 'N', 'M', 'auto_ms', 'torch_ms', 'speedup'.
    """
    print(f"\n{'='*70}")
    print(f" BATCH SWEEP β N={N}, M={M}")
    print(f"{'='*70}")
    print(f" {'B':>6} {'Auto ms':>10} {'Torch ms':>10} {'Speedup':>8} {'img/s':>12}")
    print(f" {'β'*52}")

    batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
    results = []

    for B in batch_sizes:
        try:
            A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        except RuntimeError:
            # Most likely CUDA OOM. Release cached allocator blocks so later
            # profiling stages start clean (fix: the cache was previously left
            # full on this path), then stop — every larger batch would fail too.
            print(f" {B:>6} OOM")
            torch.cuda.empty_cache()
            break

        auto_mean, _, _ = _cuda_timer(lambda: batched_svd(A, method='auto'))
        torch_mean, _, _ = _cuda_timer(
            lambda: torch.linalg.svd(A, full_matrices=False))

        # +1e-9 guards both ratios against a zero timing; the throughput
        # division was previously unguarded, unlike the sibling `speedup`.
        speedup = torch_mean / (auto_mean + 1e-9)
        ips = B * 1000.0 / (auto_mean + 1e-9)  # matrices ("images") per second

        print(f" {B:>6} {auto_mean:>8.3f}ms {torch_mean:>8.3f}ms {speedup:>7.2f}x {ips:>11,.0f}")
        results.append({'B': B, 'N': N, 'M': M,
                        'auto_ms': round(auto_mean, 4), 'torch_ms': round(torch_mean, 4),
                        'speedup': round(speedup, 3)})
        del A; torch.cuda.empty_cache()

    return results
def profile_spatial_sweep(N=3, B=512):
    """Sweep spatial dimension M for a fixed N. Shows tiling efficiency."""
    banner = "=" * 70
    print(f"\n{banner}")
    print(f" SPATIAL SWEEP β N={N}, B={B}")
    print(f"{banner}")
    print(f" {'M':>6} {'~HxW':>8} {'Auto ms':>10} {'Torch ms':>10} {'Speedup':>8}")
    print(f" {'β'*48}")

    rows = []
    for M in [16, 64, 256, 512, 1024, 2048, 4096, 8192, 16384]:
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)

        # Label perfect-square spatial sizes as HxW, everything else by raw M.
        side = int(M ** 0.5)
        if side * side == M:
            tag = f"{side}Γ{side}"
        else:
            tag = f"{M}"

        t_auto, _, _ = _cuda_timer(lambda: batched_svd(A, method='auto'))
        t_torch, _, _ = _cuda_timer(lambda: torch.linalg.svd(A, full_matrices=False))

        # +1e-9 avoids dividing by a zero timing.
        ratio = t_torch / (t_auto + 1e-9)
        print(f" {M:>6} {tag:>8} {t_auto:>8.3f}ms {t_torch:>8.3f}ms {ratio:>7.2f}x")

        rows.append({
            'M': M, 'N': N, 'B': B,
            'auto_ms': round(t_auto, 4),
            'torch_ms': round(t_torch, 4),
            'speedup': round(ratio, 3),
        })

        # Release the sweep tensor before allocating the next, larger one.
        del A
        torch.cuda.empty_cache()

    return rows
def profile_crossover_detail(M=1024, B=512):
    """Fine-grained N sweep around expected crossover points."""
    banner = "=" * 70
    print(f"\n{banner}")
    print(f" CROSSOVER DETAIL β B={B}, M={M}")
    print(f"{banner}")
    print(f" {'N':>4} {'Gram ms':>10} {'Torch ms':>10} {'Winner':>8} {'Margin':>8}")
    print(f" {'β'*46}")

    for N in range(2, 65):
        # Thin SVD needs tall matrices; stop once N would exceed M.
        if N > M:
            break

        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        t_gram, _, _ = _cuda_timer(lambda: gram_eigh_svd(A), warmup=10, iters=40)
        t_torch, _, _ = _cuda_timer(
            lambda: torch.linalg.svd(A, full_matrices=False), warmup=10, iters=40)

        if t_gram < t_torch:
            winner = "gram"
        else:
            winner = "torch"
        # Relative gap between the two timings, as a percentage of the faster.
        gap_pct = abs(t_gram - t_torch) / min(t_gram, t_torch) * 100

        print(f" {N:>4} {t_gram:>8.3f}ms {t_torch:>8.3f}ms {winner:>8} {gap_pct:>6.1f}%")
        del A
        torch.cuda.empty_cache()
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# β MAIN β
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ

def main():
    """Full profiling suite.

    Runs the validation pass, the quality analyses, and the N sweep, prints
    a human-readable summary, and writes every collected result table to
    svd_general_profile.json in the working directory.
    """
    assert torch.cuda.is_available(), "CUDA required"
    device_name = torch.cuda.get_device_name(0)
    print(f"{'='*80}")
    print(f" Generalized Batched Thin SVD β Profiling Suite")
    print(f" Device: {device_name}")
    print(f"{'='*80}")

    # Correctness first
    run_validation(B=64, M=1024)

    # Procrustes alignment quality β THE REAL QUESTION
    # Does rank-k Procrustes produce the same rotation as rank-N?
    procrustes_results = profile_procrustes_quality()

    # Projection quality analysis β energy/reconstruction perspective
    proj_results = profile_projection_quality(B=256, M=1024)

    # N dimension sweep β timing comparison
    n_results = profile_n_sweep(B=512, M=1024)

    # Skip batch/spatial/crossover sweeps by default β uncomment if needed
    batch_results = {}
    spatial_results = {}
    # for N in [3, 8, 32, 64]:
    #     batch_results[N] = profile_batch_sweep(N=N, M=1024)
    # for N in [3, 16, 48]:
    #     spatial_results[N] = profile_spatial_sweep(N=N, B=512)
    # profile_crossover_detail(M=1024, B=512)

    # Summary
    print(f"\n{'='*80}")
    print(f" SUMMARY")
    print(f"{'='*80}")
    print(f"\n Strategy by N:")
    print(f" N=2: Fused Triton (closed-form Jacobi rotation)")
    print(f" N=3: Fused Triton (cyclic Jacobi in registers)")
    print(f" N=4-32: Gram + eigh (bmm + cuSOLVER eigh) β sub-ms")
    print(f" N=48+: Projected SVD (Nβk, cheap SVD, lift back) β check quality table")
    print(f"")
    print(f" Standalone utilities:")
    print(f" newton_schulz_invsqrt(G) β batched G^{{-1/2}} via pure bmm")
    print(f" projected_svd(A, target_rank=k) β rank-k approximate SVD")
    print(f" projected_svd_quality(A, target_rank) β measure approximation quality")
    print(f"")
    print(f" Key question answered: energy_ratio and subspace_cos in quality table")

    # Save results
    report = {
        'device': device_name,
        'procrustes_quality': procrustes_results,
        'projection_quality': proj_results,
        'n_sweep': n_results,
        # JSON object keys must be strings, so the int N keys are stringified.
        'batch_sweeps': {str(k): v for k, v in batch_results.items()},
        'spatial_sweeps': {str(k): v for k, v in spatial_results.items()},
    }
    # NOTE(review): `json` is presumably imported at the top of the file,
    # which is outside this chunk — verify.
    with open('svd_general_profile.json', 'w') as f:
        json.dump(report, f, indent=2)
    print(f"\n Results saved to svd_general_profile.json")
    print(f"{'='*80}")
# Allow running the profiler directly as a script.
if __name__ == "__main__":
    main()