pong / src /nn /pe.py
chrisxx's picture
Update src/nn/pe.py
a9440e0 verified
import torch as t
import torch.nn as nn
import math
from jaxtyping import Float, Bool, Int
from torch import Tensor
from typing import Optional
from pdb import set_trace
def compute_trig(d_head, n_ctx, C):
    """Build the sine and cosine lookup tables used by the rotary embeddings.

    Frequency ``k`` is ``C ** (-2k / d_head)`` for ``k = 0, 1, ...``. Each
    frequency is written into two adjacent columns, so one table row lines up
    with the (even, odd) channel pairs of a head.

    Args:
        d_head: number of channels to cover (expected to be even).
        n_ctx: number of positions; rows correspond to positions 0..n_ctx-1.
        C: base of the geometric frequency progression.

    Returns:
        ``(sins, coss)``: two tensors of shape ``(n_ctx, d_head)`` (for even
        ``d_head``) where row ``p`` holds sin / cos of ``p * frequency``.
    """
    freqs = t.exp(-math.log(C) * t.arange(0, d_head, 2) / d_head)
    # Duplicate every frequency in place: [f0, f0, f1, f1, ...].
    paired_freqs = freqs.repeat_interleave(2)
    # Outer product position x frequency via broadcasting.
    angles = t.arange(n_ctx)[:, None] * paired_freqs[None, :]
    return t.sin(angles), t.cos(angles)
class NumericEncoding(nn.Module):
    """Lookup table of sinusoidal encodings for non-negative integers.

    Row ``n`` of the ``pe`` buffer holds ``sin(n * w_k)`` in the even columns
    and ``cos(n * w_k)`` in the odd columns, with ``w_k = C ** (-2k / dim)``.
    """

    def __init__(self, C=1e4, dim=64, n_max=10000):
        super().__init__()
        freqs = t.exp(- math.log(C) * t.arange(0, dim, 2) / dim)
        numbers = t.arange(n_max)
        angles = numbers[:, None] * freqs[None, :]
        # Pair sin/cos along a trailing axis, then collapse it so the columns
        # interleave as [sin_0, cos_0, sin_1, cos_1, ...].
        interleaved = t.stack((t.sin(angles), t.cos(angles)), dim=2)
        table = interleaved.reshape(n_max, dim)
        self.register_buffer("pe", table)

    def forward(self, num):
        """
        Return the table rows selected by ``num``.

        expects an int32 or int64 tensor of indices; valid rows are
        0 .. n_max - 1
        """
        accepted = (t.int32, t.int64)
        assert num.dtype in accepted, f"wrong dtype {num.dtype}"
        return self.pe[num]
class RoPE(nn.Module):
    """Rotary position embedding with fixed (non-trainable) sin/cos tables.

    The tables come from ``compute_trig`` and are stored as buffers with
    singleton batch and head axes, i.e. shape ``(1, n_ctx, 1, d_head)`` for
    even ``d_head``, so they broadcast against a
    ``(batch, sequence, n_head, d_head)`` input.

    Args:
        d_head: channels per attention head (expected to be even, since
            channels are rotated in (even, odd) pairs).
        n_ctx: number of positions the tables cover (0 .. n_ctx - 1).
        C: base of the frequency progression passed to ``compute_trig``.
    """

    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        self.d_head = d_head
        self.n_ctx = n_ctx
        self.C = C
        sins, coss = compute_trig(d_head, n_ctx, C)
        # Add broadcast axes for batch (dim 0) and head (dim 2).
        self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
        self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                offset: int = 0):
        """Rotate each (even, odd) channel pair by its position-dependent angle.

        Args:
            key_or_query: tensor of shape (batch, sequence, n_head, d_head).
            offset: position of the first token in the sequence; token ``i``
                uses table row ``offset + i``.

        Returns:
            A new tensor ``cos * x + sin * x_perm`` where ``x_perm`` swaps each
            channel pair and negates the first entry: (a, b) -> (-b, a).

        Raises:
            AssertionError: if the sequence axis is empty.
            ValueError: if ``offset .. offset + sequence`` does not lie inside
                the precomputed tables.
        """
        x = key_or_query
        seq_len = x.shape[1]
        assert seq_len >= 1, f"x.shape[1] must be >= 1, got {x.shape}"

        # Without this check an out-of-range slice is silently shortened; a
        # length-1 slice would then broadcast over the whole sequence and
        # apply the wrong rotation instead of failing.
        end = offset + seq_len
        table_len = self.coss.shape[1]
        if offset < 0 or end > table_len:
            raise ValueError(
                f"positions [{offset}, {end}) fall outside the precomputed "
                f"table of length {table_len}"
            )

        # Strided slices replace index tensors, so nothing has to be created
        # on (or moved to) the input's device.
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        x_perm[:, :, :, 0::2] = -x[:, :, :, 1::2]
        x_perm[:, :, :, 1::2] = x[:, :, :, 0::2]
        return self.coss[:, offset:end] * x + self.sins[:, offset:end] * x_perm
class LearnRoPE(nn.Module):
    """Rotary position embedding whose sin/cos tables are trainable.

    The tables are initialised from ``compute_trig`` and registered as
    ``nn.Parameter``s with singleton batch and head axes, i.e. shape
    ``(1, n_ctx, 1, d_head)`` for even ``d_head``, so they broadcast against a
    ``(batch, sequence, n_head, d_head)`` input and receive gradients.

    Args:
        d_head: channels per attention head (expected to be even, since
            channels are rotated in (even, odd) pairs).
        n_ctx: number of positions the tables cover (0 .. n_ctx - 1).
        C: base of the frequency progression passed to ``compute_trig``.
    """

    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        self.d_head = d_head
        self.n_ctx = n_ctx
        self.C = C
        sins, coss = compute_trig(d_head, n_ctx, C)
        # Add broadcast axes for batch (dim 0) and head (dim 2).
        self.sins = nn.Parameter(sins.unsqueeze(0).unsqueeze(2))
        self.coss = nn.Parameter(coss.unsqueeze(0).unsqueeze(2))

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                offset: int = 0):
        """Apply the (learned) rotation tables to a key or query tensor.

        Args:
            key_or_query: tensor of shape (batch, sequence, n_head, d_head).
            offset: position of the first token in the sequence; token ``i``
                uses table row ``offset + i``.

        Returns:
            A new tensor ``coss * x + sins * x_perm`` where ``x_perm`` swaps
            each channel pair and negates the first entry: (a, b) -> (-b, a).

        Raises:
            AssertionError: if the sequence axis is empty.
            ValueError: if ``offset .. offset + sequence`` does not lie inside
                the tables.
        """
        x = key_or_query
        seq_len = x.shape[1]
        assert seq_len >= 1, f"x.shape[1] must be >= 1, got {x.shape}"

        # Without this check an out-of-range slice is silently shortened; a
        # length-1 slice would then broadcast over the whole sequence and
        # apply the wrong rotation instead of failing.
        end = offset + seq_len
        table_len = self.coss.shape[1]
        if offset < 0 or end > table_len:
            raise ValueError(
                f"positions [{offset}, {end}) fall outside the precomputed "
                f"table of length {table_len}"
            )

        # Strided slices replace index tensors, so nothing has to be created
        # on (or moved to) the input's device.
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        x_perm[:, :, :, 0::2] = -x[:, :, :, 1::2]
        x_perm[:, :, :, 1::2] = x[:, :, :, 0::2]
        return self.coss[:, offset:end] * x + self.sins[:, offset:end] * x_perm
class VidRoPE(nn.Module):
    """Rotary position embedding for video tokens with separate x, y, t axes.

    Within each head, the first ``d_x`` channels are rotated by the token's
    column index, the next ``d_y`` channels by its row index and the next
    ``d_t`` channels by its frame index. Any remaining channels
    (``d_head - d_x - d_y - d_t``) are passed through unchanged.

    Token layout encoded in the precomputed position tables: for each of the
    ``ctx_t`` frames, ``ctx_y * ctx_x`` grid tokens with the column index
    cycling fastest, followed by ONE register token that is assigned column
    ``ctx_x`` and row ``ctx_y`` (one past the grid).

    NOTE(review): ``n_registers`` is stored but not used by this class; the
    tables hard-code a single register per frame. Because ``rotate`` requires
    every position table to match the sequence length, ``forward`` can only
    succeed when ``toks_per_frame == ctx_x * ctx_y + 1``.

    Args:
        d_head: channels per attention head.
        d_x, d_y, d_t: channels rotated by column / row / frame index; their
            sum must not exceed ``d_head``.
        ctx_x, ctx_y, ctx_t: grid width, grid height and number of frames.
        C_x, C_y, C_t: frequency bases passed to ``compute_trig`` per axis.
        toks_per_frame: tokens per frame, used to build the frame-index table
            and to convert ``offset`` into a frame shift.
        n_registers: stored on the module; see the note above.
    """

    def __init__(self, d_head,
                 d_x,
                 d_y,
                 d_t,
                 ctx_x,
                 ctx_y,
                 ctx_t,
                 C_x,
                 C_y,
                 C_t,
                 toks_per_frame,
                 n_registers):
        super().__init__()
        assert d_x + d_y + d_t <= d_head, "dx + dy + dt > d_head"
        self.d_head = d_head
        self.d_x = d_x
        self.d_y = d_y
        self.d_t = d_t
        self.ctx_x = ctx_x
        self.ctx_y = ctx_y
        self.ctx_t = ctx_t
        self.C_x = C_x
        self.C_y = C_y
        self.C_t = C_t
        self.toks_per_frame = toks_per_frame
        self.n_registers = n_registers

        # Trig tables get broadcast axes for batch (dim 0) and head (dim 2).
        sins_x, coss_x = compute_trig(d_x, ctx_x+1, C_x)  # +1 for the register
        self.register_buffer("sins_x", sins_x.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_x", coss_x.unsqueeze(0).unsqueeze(2))
        sins_y, coss_y = compute_trig(d_y, ctx_y+1, C_y)  # +1 for the register
        self.register_buffer("sins_y", sins_y.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_y", coss_y.unsqueeze(0).unsqueeze(2))
        sins_t, coss_t = compute_trig(d_t, ctx_t, C_t)
        self.register_buffer("sins_t", sins_t.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_t", coss_t.unsqueeze(0).unsqueeze(2))

        n_frames = ctx_t
        # ctx_x should be equal to width
        # ctx_y should be equal to height
        # Column index of every grid token: 0..ctx_x-1 repeated once per row.
        pos_x = t.arange(self.ctx_x).repeat(self.ctx_y)
        # The register index uses the same dtype as the grid indices so the
        # concatenation does not depend on torch.cat type promotion.
        pos_x = t.cat([pos_x, t.tensor([self.ctx_x], dtype=pos_x.dtype)])  # deal with register
        pos_x = pos_x.repeat(n_frames)
        # Row index of every grid token: each row value repeated ctx_x times.
        pos_y = t.arange(self.ctx_y).repeat_interleave(self.ctx_x)
        pos_y = t.cat([pos_y, t.tensor([self.ctx_y], dtype=pos_y.dtype)])  # deal with register
        pos_y = pos_y.repeat(n_frames)
        # Frame index of every token.
        pos_t = t.arange(n_frames).repeat_interleave(self.toks_per_frame)
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)
        self.register_buffer("pos_t", pos_t)

    def rotate(self, x, pos_idcs, coss, sins):
        """Rotate the channel pairs of ``x`` using table rows ``pos_idcs``.

        Args:
            x: tensor of shape (batch, sequence, n_head, channels).
            pos_idcs: 1-D index tensor with one table row per sequence token.
            coss, sins: trig tables whose dim 1 is indexed by ``pos_idcs``.

        Returns:
            ``coss[:, pos_idcs] * x + sins[:, pos_idcs] * x_perm`` where
            ``x_perm`` maps each channel pair (a, b) to (-b, a).

        Raises:
            AssertionError: if the sequence axis is empty or its length
                differs from ``len(pos_idcs)``.
        """
        assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        assert pos_idcs.shape[0] == x.shape[1], f"pos_idcs length {pos_idcs.shape[0]} must match x.shape[1] {x.shape[1]}"
        # batch sequence n_head d_head, we perm the last axis
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        x_perm[:, :, :, 0::2] = -x[:, :, :, 1::2]
        x_perm[:, :, :, 1::2] = x[:, :, :, 0::2]
        return coss[:, pos_idcs]*x + sins[:, pos_idcs]*x_perm

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                offset: int = 0):
        """Apply the x, y and t rotations to their respective channel groups.

        The input is left untouched; the rotated values are written into a
        copy, matching the non-mutating behaviour of ``RoPE`` / ``LearnRoPE``.

        Args:
            key_or_query: tensor of shape (batch, sequence, n_head, d_head).
            offset: token offset; ``offset // toks_per_frame`` is added to
                every frame index before the time table is looked up, so the
                shifted indices must still address valid rows of that table.

        Returns:
            A tensor with the same shape and dtype as the input.
        """
        x = key_or_query
        out = x.clone()
        x_end = self.d_x
        y_end = x_end + self.d_y
        t_end = y_end + self.d_t
        frame_shift = offset // self.toks_per_frame

        out[:, :, :, :x_end] = self.rotate(
            x[:, :, :, :x_end], self.pos_x, self.coss_x, self.sins_x)
        out[:, :, :, x_end:y_end] = self.rotate(
            x[:, :, :, x_end:y_end], self.pos_y, self.coss_y, self.sins_y)
        out[:, :, :, y_end:t_end] = self.rotate(
            x[:, :, :, y_end:t_end], self.pos_t + frame_shift, self.coss_t, self.sins_t)
        return out