# artigen / phi_scan.py
# Uploaded by krystv (commit 1c281a6, verified).
"""
PHI-SCAN: Physics-informed multi-directional token scanners.
Alternates between scan patterns per layer with zero extra parameters.
"""
import torch
# Scan orders cycled through by get_scan_pattern, one per layer.
SCAN_PATTERNS = [
    "row_major",    # row by row (identity order)
    "col_major",    # column by column (transposed grid)
    "hilbert",      # Hilbert space-filling-curve order
    "zigzag_diag",  # anti-diagonals, alternating direction
]
def _hilbert_index(n, x, y):
    """Return the distance of cell ``(x, y)`` along a Hilbert curve on an n x n grid.

    The walk descends one quadrant level per iteration, from the coarsest
    (``n // 2``) to the finest (1). Callers in this module round ``n`` up to
    a power of two before calling.
    """
    distance = 0
    half = n // 2
    while half > 0:
        bit_x = int((x & half) > 0)
        bit_y = int((y & half) > 0)
        # Map quadrant (bit_x, bit_y) to its position 0..3 in curve order.
        distance += half * half * ((3 * bit_x) ^ bit_y)
        # Reflect/transpose so the next, finer level sees a canonical curve.
        if bit_y == 0:
            if bit_x == 1:
                x, y = n - 1 - x, n - 1 - y
            x, y = y, x
        half //= 2
    return distance
def build_hilbert_permutation(h: int, w: int, device='cpu'):
    """Build the Hilbert-curve scan order for an ``h x w`` grid.

    The grid is embedded in a power-of-two square with side >= max(h, w),
    and cells are ordered by their Hilbert distance in that square. For
    non-square or non-power-of-two grids, consecutive visited cells may
    therefore not be spatially adjacent.

    Returns:
        ``(perm, inv)`` long tensors on ``device``: ``perm[i]`` is the
        row-major flat index (``y * w + x``) of the i-th visited cell, and
        ``inv`` satisfies ``inv[perm] == arange(h * w)``.
    """
    # Power-of-two side length that covers both grid dimensions.
    side = 1 << (max(h, w) - 1).bit_length()
    total = h * w
    # Hilbert distance of each cell, stored at its row-major flat position.
    distance = [-1] * total
    for row in range(h):
        for col in range(w):
            distance[row * w + col] = _hilbert_index(side, col, row)
    # Flat positions sorted by Hilbert distance give the visiting order.
    visit_order = sorted(range(total), key=distance.__getitem__)
    perm = torch.tensor(visit_order, dtype=torch.long, device=device)
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(total, device=device)
    return perm, inv
def build_zigzag_diag_permutation(h: int, w: int, device='cpu'):
    """Build an alternating anti-diagonal (zigzag) scan order for an ``h x w`` grid.

    Anti-diagonals ``x + y = k`` are visited for ``k = 0, 1, ...``. Even
    diagonals are walked with increasing row ``y`` and odd diagonals with
    decreasing ``y``, so the direction alternates from one diagonal to the next.

    Returns:
        ``(perm, inv)`` long tensors on ``device``: ``perm[i]`` is the
        row-major flat index (``y * w + x``) of the i-th visited cell, and
        ``inv`` satisfies ``inv[perm] == arange(h * w)``.
    """
    order = []
    for k in range(h + w - 1):
        # Rows y in [0, h) whose column x = k - y also lies in [0, w).
        rows = range(max(0, k - w + 1), min(h - 1, k) + 1)
        if k % 2 == 1:
            rows = reversed(rows)
        order.extend(y * w + (k - y) for y in rows)
    perm = torch.tensor(order, dtype=torch.long, device=device)
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(h * w, device=device)
    return perm, inv
def build_scan_permutations(h: int, w: int, device='cpu'):
    """Build the ``(perm, inv)`` pair for every scan pattern on an ``h x w`` grid.

    Returns:
        dict keyed by "row_major", "col_major", "hilbert" and "zigzag_diag".
        Each value is a pair of long tensors on ``device``: ``perm`` reorders
        row-major flat token positions into that scan order, and ``inv`` is
        its inverse permutation.
    """
    total = h * w
    # Column-major order: read the row-major index grid transposed.
    col_perm = torch.arange(total, device=device).reshape(h, w).t().reshape(-1)
    col_inv = torch.empty_like(col_perm)
    col_inv[col_perm] = torch.arange(total, device=device)
    scans = {
        # Row-major is the identity; keep perm and inv as distinct tensors.
        "row_major": (
            torch.arange(total, device=device),
            torch.arange(total, device=device),
        ),
        "col_major": (col_perm, col_inv),
    }
    scans["hilbert"] = build_hilbert_permutation(h, w, device)
    scans["zigzag_diag"] = build_zigzag_diag_permutation(h, w, device)
    return scans
def apply_scan(x: torch.Tensor, perm: torch.Tensor):
    """Reorder the tokens of ``x`` into scan order.

    Args:
        x: tensor of shape ``(B, L, C)``. The permutation is applied along the
            token dimension ``L`` (dim 1).
        perm: index tensor for dim 1, typically the ``perm`` half of a pair
            returned by ``build_scan_permutations``.

    Returns:
        ``x[:, perm, :]``, a new tensor (advanced indexing copies).

    Raises:
        ValueError: if ``x`` is not 3-D.
    """
    # Explicit rank check in place of the former unused `B, L, C` unpacking,
    # which raised an opaque "values to unpack" ValueError.
    if len(x.shape) != 3:
        raise ValueError(
            f"apply_scan expects a (B, L, C) tensor, got shape {tuple(x.shape)}"
        )
    return x[:, perm, :]
def unscan(x: torch.Tensor, inv: torch.Tensor):
    """Restore the tokens of ``x`` from scan order back to row-major order.

    Args:
        x: tensor of shape ``(B, L, C)`` whose token dimension (dim 1) is in
            scan order, e.g. the output of ``apply_scan``.
        inv: inverse permutation for dim 1, typically the ``inv`` half of a
            pair returned by ``build_scan_permutations``. With that pair,
            ``unscan(apply_scan(x, perm), inv)`` equals ``x``.

    Returns:
        ``x[:, inv, :]``, a new tensor (advanced indexing copies).

    Raises:
        ValueError: if ``x`` is not 3-D.
    """
    # Explicit rank check in place of the former unused `B, L, C` unpacking,
    # which raised an opaque "values to unpack" ValueError.
    if len(x.shape) != 3:
        raise ValueError(
            f"unscan expects a (B, L, C) tensor, got shape {tuple(x.shape)}"
        )
    return x[:, inv, :]
def get_scan_pattern(layer_idx: int):
    """Return the scan-pattern name used by layer ``layer_idx``.

    Layers cycle through ``SCAN_PATTERNS`` in order. The index wraps around
    using Python's floor modulo, so negative indices wrap as well.
    """
    _, slot = divmod(layer_idx, len(SCAN_PATTERNS))
    return SCAN_PATTERNS[slot]