| """ |
| PHI-SCAN: Physics-informed multi-directional token scanners. |
| Alternates between scan patterns per layer with zero extra parameters. |
| """ |
| import torch |
|
|
# Names of the available scan orders, in the order they are cycled through
# when a pattern is picked by index.
SCAN_PATTERNS = [
    "row_major",    # identity ordering of the flattened grid
    "col_major",    # transpose of the (h, w) grid, then flatten
    "hilbert",      # cells sorted by Hilbert-curve index
    "zigzag_diag",  # anti-diagonals, alternating direction
]
|
|
|
|
| def _hilbert_index(n, x, y): |
| d = 0 |
| s = n // 2 |
| while s > 0: |
| rx = 1 if (x & s) > 0 else 0 |
| ry = 1 if (y & s) > 0 else 0 |
| d += s * s * ((3 * rx) ^ ry) |
| if ry == 0: |
| if rx == 1: |
| x = n - 1 - x |
| y = n - 1 - y |
| x, y = y, x |
| s //= 2 |
| return d |
|
|
|
|
def build_hilbert_permutation(h: int, w: int, device='cpu'):
    """Build the Hilbert-curve scan order for an ``h`` x ``w`` grid.

    The curve is evaluated on a square whose side is ``max(h, w)`` rounded up
    to a power of two; only cells inside the ``h`` x ``w`` grid are kept and
    they are ordered by their curve index.

    Returns:
        ``(perm, inv)`` as ``torch.long`` tensors on ``device``. ``perm[i]``
        is the flat row-major index (``y * w + x``) of the i-th cell visited;
        ``inv[j]`` is the scan position of flat index ``j``.
    """
    side = 1 << (max(h, w) - 1).bit_length()

    # Curve index of every cell, stored at the cell's flat row-major index.
    curve_pos = [
        _hilbert_index(side, col, row)
        for row in range(h)
        for col in range(w)
    ]

    # Stable sort of the flat indices by curve position.
    visit_order = sorted(range(h * w), key=curve_pos.__getitem__)

    perm = torch.tensor(visit_order, dtype=torch.long, device=device)
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(h * w, device=device)
    return perm, inv
|
|
|
|
def build_zigzag_diag_permutation(h: int, w: int, device='cpu'):
    """Build the zigzag anti-diagonal scan order for an ``h`` x ``w`` grid.

    Cells are grouped by ``x + y``. Even-numbered diagonals are walked with
    the row index increasing, odd-numbered ones with it decreasing.

    Returns:
        ``(perm, inv)`` as ``torch.long`` tensors on ``device``. ``perm[i]``
        is the flat row-major index (``y * w + x``) of the i-th cell visited;
        ``inv[j]`` is the scan position of flat index ``j``.
    """
    visit_order = []
    for d in range(h + w - 1):
        # Rows that actually intersect diagonal d inside the grid.
        rows = range(max(0, d - w + 1), min(h - 1, d) + 1)
        if d % 2:
            rows = reversed(rows)
        visit_order.extend(row * w + (d - row) for row in rows)

    perm = torch.tensor(visit_order, dtype=torch.long, device=device)
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(h * w, device=device)
    return perm, inv
|
|
|
|
def build_scan_permutations(h: int, w: int, device='cpu'):
    """Build every scan order for an ``h`` x ``w`` grid.

    Returns:
        A dict keyed by ``"row_major"``, ``"col_major"``, ``"hilbert"`` and
        ``"zigzag_diag"``. Each value is a ``(perm, inv)`` pair of index
        tensors on ``device``, where ``inv`` undoes ``perm``.
    """
    n_tokens = h * w

    # Row-major is the identity ordering; forward and inverse are kept as two
    # separate tensors.
    identity_fwd = torch.arange(n_tokens, device=device)
    identity_bwd = torch.arange(n_tokens, device=device)

    # Column-major: lay flat indices on the (h, w) grid, transpose, flatten.
    grid = torch.arange(n_tokens, device=device).reshape(h, w)
    col_fwd = grid.t().reshape(-1)
    col_bwd = torch.empty_like(col_fwd)
    col_bwd[col_fwd] = torch.arange(n_tokens, device=device)

    hilbert_fwd, hilbert_bwd = build_hilbert_permutation(h, w, device)
    zigzag_fwd, zigzag_bwd = build_zigzag_diag_permutation(h, w, device)

    table = {}
    table["row_major"] = (identity_fwd, identity_bwd)
    table["col_major"] = (col_fwd, col_bwd)
    table["hilbert"] = (hilbert_fwd, hilbert_bwd)
    table["zigzag_diag"] = (zigzag_fwd, zigzag_bwd)
    return table
|
|
|
|
def apply_scan(x: torch.Tensor, perm: torch.Tensor):
    """Reorder the tokens of ``x`` along dim 1 according to ``perm``.

    Args:
        x: Tensor of shape ``(B, L, C)``.
        perm: Index tensor applied to dim 1; output token ``i`` is input
            token ``perm[i]``.

    Returns:
        ``x[:, perm, :]``.

    Raises:
        ValueError: If ``x`` does not have exactly three dimensions.
    """
    # Explicit rank check: keeps the ValueError that a three-way shape unpack
    # would raise, but with a readable message and no unused locals.
    if len(x.shape) != 3:
        raise ValueError(
            f"apply_scan expects a 3-D tensor (B, L, C), got shape {tuple(x.shape)}"
        )
    return x[:, perm, :]
|
|
|
|
def unscan(x: torch.Tensor, inv: torch.Tensor):
    """Restore the original token order of a scanned tensor.

    Args:
        x: Tensor of shape ``(B, L, C)`` whose dim 1 is in scan order.
        inv: Inverse-permutation index tensor applied to dim 1; output token
            ``j`` is input token ``inv[j]``.

    Returns:
        ``x[:, inv, :]``.

    Raises:
        ValueError: If ``x`` does not have exactly three dimensions.
    """
    # Explicit rank check: keeps the ValueError that a three-way shape unpack
    # would raise, but with a readable message and no unused locals.
    if len(x.shape) != 3:
        raise ValueError(
            f"unscan expects a 3-D tensor (B, L, C), got shape {tuple(x.shape)}"
        )
    return x[:, inv, :]
|
|
|
|
def get_scan_pattern(layer_idx: int):
    """Return the scan pattern name for a layer, cycling through ``SCAN_PATTERNS``."""
    slot = layer_idx % len(SCAN_PATTERNS)
    return SCAN_PATTERNS[slot]
|
|