# Reframr-RFM-v1-Base public checkpoint release (OkeyMeta, commit 2147ce8, verified).
import math
import site
import sys
from dataclasses import dataclass, field
from pathlib import Path

from .linalg import Matrix, Vector, identity, invert_matrix, matvec
# Prefer dependencies bundled in a ``.vendor`` directory that sits next to this
# file's parent directory: each vendor sub-directory that exists is put at the
# front of ``sys.path`` (once) so it wins over globally installed packages.
_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

# NumPy is optional. If it cannot be imported from the current path, retry
# once with the per-user site-packages appended; if that also fails, ``np`` is
# None and the ``*_fast`` helpers in this module fall back to pure Python.
#
# ``ImportError`` is caught (not only its subclass ``ModuleNotFoundError``) so
# that an installed-but-broken NumPy -- e.g. a compiled extension that fails to
# load -- degrades to the fallback instead of breaking this module's import.
try:
    import numpy as np
except ImportError:
    try:
        user_site = site.getusersitepackages()
    except AttributeError:
        # Some legacy virtualenv ``site`` modules do not define this function.
        user_site = None
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ImportError:
        np = None
def hippo_legs_matrix(order: int) -> tuple[Matrix, Vector]:
    """Build the continuous-time HiPPO-LegS pair ``(A, B)`` of size ``order``.

    With ``s(n) = sqrt(2n + 1)``:

    * ``A[n][k] = -s(n) * s(k)`` strictly below the diagonal (``n > k``),
    * ``A[n][n] = -(n + 1)`` on the diagonal (stored as an ``int``),
    * ``A[n][k] = 0.0`` above the diagonal,
    * ``B[n] = s(n)``.

    A non-positive ``order`` yields ``([], [])``.
    """
    scales = [math.sqrt(2 * n + 1) for n in range(order)]

    def entry(n: int, k: int):
        # Lower-triangular structure: diagonal, strictly-lower, then zeros.
        if n == k:
            return -(n + 1)
        if n > k:
            return -scales[n] * scales[k]
        return 0.0

    a_matrix = [[entry(n, k) for k in range(order)] for n in range(order)]
    return a_matrix, list(scales)
def analytical_embedding_drive(embedding: Vector, state_dim: int) -> Vector:
    """Fold ``embedding`` into a drive vector of length ``state_dim``.

    Output slot ``i`` mixes three taps of the embedding, each wrapped modulo
    ``len(embedding)``::

        e[i] + 0.5 * e[3i + 1] - 0.25 * e[5i + 2]

    An empty embedding produces ``state_dim`` zeros.
    """
    if not embedding:
        return [0.0] * state_dim
    width = len(embedding)
    drive = []
    for position in range(state_dim):
        primary = embedding[position % width]
        secondary = embedding[(3 * position + 1) % width]
        tertiary = embedding[(5 * position + 2) % width]
        drive.append(primary + 0.5 * secondary - 0.25 * tertiary)
    return drive
def analytical_embedding_drive_fast(embedding: object, state_dim: int) -> object:
    """Vectorised counterpart of ``analytical_embedding_drive``.

    With NumPy available, returns an array whose slot ``i`` is
    ``e[i] + 0.5 * e[3i + 1] - 0.25 * e[5i + 2]`` (indices wrapped modulo the
    first-axis length). Objects exposing ``shape`` are used as-is; anything
    else is converted to a float64 array. An empty input gives
    ``np.zeros(state_dim)``. Without NumPy, the input is converted to a list
    and the pure-Python function is used (which returns a list).
    """
    if np is None:
        as_list = list(embedding) if not hasattr(embedding, "tolist") else embedding.tolist()
        return analytical_embedding_drive(as_list, state_dim)
    values = np.asarray(embedding, dtype=np.float64) if not hasattr(embedding, "shape") else embedding
    if not values.size:
        return np.zeros(state_dim, dtype=np.float64)
    width = int(values.shape[0])
    positions = np.arange(state_dim, dtype=np.int64)
    primary = values[positions % width]
    secondary = values[(3 * positions + 1) % width]
    tertiary = values[(5 * positions + 2) % width]
    return primary + 0.5 * secondary - 0.25 * tertiary
@dataclass(slots=True)
class AnalyticalMemoryUnit:
state_dim: int
timescale: float
def __post_init__(self) -> None:
a_matrix, b_vector = hippo_legs_matrix(self.state_dim)
self.transition, self.input_projection = self._discretize_transition(
a_matrix,
b_vector,
self.timescale,
)
transition: Matrix = None # type: ignore[assignment]
input_projection: Vector = None # type: ignore[assignment]
transition_array: object | None = None # type: ignore[assignment]
input_projection_array: object | None = None # type: ignore[assignment]
@staticmethod
def _discretize_transition(
a_matrix: Matrix,
b_vector: Vector,
step: float,
) -> tuple[Matrix, Vector]:
implicit_system = [
[
identity_value - step * a_value
for identity_value, a_value in zip(identity_row, a_row)
]
for identity_row, a_row in zip(identity(len(a_matrix)), a_matrix)
]
transition = invert_matrix(implicit_system)
input_projection = matvec(transition, [step * value for value in b_vector])
return transition, input_projection
def step(self, state: Vector, scalar_input: float) -> Vector:
if np is not None and self.transition_array is None:
self.transition_array = np.asarray(self.transition, dtype=np.float64)
self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
propagated = matvec(self.transition, state)
return [
propagated[index] + self.input_projection[index] * scalar_input
for index in range(self.state_dim)
]
def step_vector(self, state: Vector, drive: Vector) -> Vector:
propagated = matvec(self.transition, state)
return [
propagated[index] + self.input_projection[index] * drive[index]
for index in range(self.state_dim)
]
def step_fast(self, state: object, scalar_input: float) -> object:
if np is None:
state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
return self.step(state_vector, scalar_input)
if self.transition_array is None or self.input_projection_array is None:
self.transition_array = np.asarray(self.transition, dtype=np.float64)
self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
return (self.transition_array @ state_array) + (self.input_projection_array * scalar_input)
def step_vector_fast(self, state: object, drive: object) -> object:
if np is None:
state_vector = state.tolist() if hasattr(state, "tolist") else list(state)
drive_vector = drive.tolist() if hasattr(drive, "tolist") else list(drive)
return self.step_vector(state_vector, drive_vector)
if self.transition_array is None or self.input_projection_array is None:
self.transition_array = np.asarray(self.transition, dtype=np.float64)
self.input_projection_array = np.asarray(self.input_projection, dtype=np.float64)
state_array = state if hasattr(state, "shape") else np.asarray(state, dtype=np.float64)
drive_array = drive if hasattr(drive, "shape") else np.asarray(drive, dtype=np.float64)
return (self.transition_array @ state_array) + (self.input_projection_array * drive_array)