File size: 1,371 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import torch
from einops import repeat
from jaxtyping import Int
from torch import Tensor
# jaxtyping annotation for an integer tensor of shape (n, n - 1): one row per
# item and one column per *other* item (the "heterogeneous" pair layout built
# by the functions below, which excludes self-pairs).
Index = Int[Tensor, "n n-1"]
def generate_heterogeneous_index(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Enumerate every ordered pair of distinct items as two (n, n - 1) indices.

    Row ``i`` of both outputs belongs to item ``i``. Column ``j`` of that row
    refers to the ``j``-th item other than ``i``, i.e. ``j`` when ``j < i`` and
    ``j + 1`` otherwise, so self-pairs never appear.

    Args:
        n: Number of items.
        device: Device on which the index tensors are allocated.

    Returns:
        ``(index_self, index_other)`` where ``index_self[i, j] == i`` and
        ``index_other[i, j]`` is the ``j``-th item distinct from ``i``.
    """
    items = torch.arange(n, device=device)
    upper = torch.ones((n, n), device=device, dtype=torch.int64).triu()

    # Every row of the pair grid is anchored on its own item.
    index_self = repeat(items, "h -> h w", w=n - 1)

    # Start from rows that all read 0..n-1, then bump the entries on and above
    # the diagonal by one so column j of row i reads j + 1 whenever j >= i.
    # The last column then always holds the out-of-range value n; slicing it
    # off leaves exactly the n - 1 items that differ from i.
    shifted = repeat(items, "w -> h w", h=n) + upper
    index_other = shifted[:, :-1]

    return index_self, index_other
def generate_heterogeneous_index_transpose(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Generate an index that can be used to "transpose" the heterogeneous index.

    In the (n, n - 1) heterogeneous layout, entry ``(i, j)`` stands for the
    ordered pair ``(i, k)`` with ``k = j if j < i else j + 1``. For every entry,
    the returned tensors give the location of the reversed pair ``(k, i)`` in
    that same layout: row ``k`` and column ``i if i < k else i - 1``. Gathering
    a layout-shaped tensor ``x`` as ``x[index_self, index_other]`` therefore
    swaps the roles of the two items in each pair. Applying the index a second
    time inverts the "transpose."

    Args:
        n: Number of items.
        device: Device on which the index tensors are allocated.

    Returns:
        ``(index_self, index_other)``, each of shape ``(n, n - 1)``: the row and
        the column of the reversed pair for every entry.
    """
    arange = torch.arange(n, device=device)
    # Upper-triangular mask (diagonal included), built once and reused below.
    upper = torch.ones((n, n), device=device, dtype=torch.int64).triu()

    # Row of the reversed pair is the partner k = j + [j >= i]. The addition is
    # out-of-place, so the broadcast view returned by repeat needs no clone.
    index_self = repeat(arange, "w -> h w", h=n) + upper

    # Column of the reversed pair is i, minus one strictly below the diagonal
    # (where j < i, hence k < i and i's slot in row k shifts left by one).
    index_other = repeat(arange, "h -> h w", w=n) - (1 - upper)

    # The last column corresponds to no real pair (its partner would be k = n).
    return index_self[:, :-1], index_other[:, :-1]
|