# motoko-1-1b / preprocessor / feature_extractor.py
# (commit 89e5d21, "update")
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import numpy as np
class MotokoFeatureExtractor:
    """Normalize and stack haptic modalities into a single model tensor.

    Each enabled modality is z-score normalized per channel, padded or
    trimmed to ``max_length`` timesteps, and concatenated along the channel
    axis. An attention mask marks which timesteps carry real (non-padding)
    data.

    Config keys:
        max_length: target number of timesteps (default 2048).
        padding_value: fill value for padded rows (default 0.0).
        normalization.eps: floor for the per-channel std (default 1e-6).
        modalities: mapping of name -> {"enabled": bool, "channels": int}.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        self.config = config
        self.max_length = int(config.get("max_length", 2048))
        self.padding_value = float(config.get("padding_value", 0.0))
        self.eps = float(config.get("normalization", {}).get("eps", 1e-6))
        self.modalities = config.get("modalities", {})

    @classmethod
    def from_config(cls, path: str | Path) -> "MotokoFeatureExtractor":
        """Build an extractor from a JSON config file at *path*."""
        with Path(path).open("r", encoding="utf-8") as handle:
            return cls(json.load(handle))

    def _normalize(self, values: np.ndarray) -> np.ndarray:
        """Z-score normalize per channel; ``eps`` floors the std to avoid
        division by zero on constant channels."""
        if values.shape[0] == 0:
            # Empty sequence: mean/std over a zero-length axis would emit
            # RuntimeWarnings and yield NaN statistics; there is nothing
            # to normalize, so return the (empty) array untouched.
            return values
        mean = values.mean(axis=0, keepdims=True)
        std = values.std(axis=0, keepdims=True)
        return (values - mean) / np.maximum(std, self.eps)

    def _pad_or_trim(self, values: np.ndarray) -> np.ndarray:
        """Trim to ``max_length`` rows, or pad the tail with ``padding_value``."""
        if values.shape[0] >= self.max_length:
            return values[: self.max_length]
        pad_rows = self.max_length - values.shape[0]
        pad = np.full((pad_rows, values.shape[1]), self.padding_value, dtype=values.dtype)
        return np.concatenate([values, pad], axis=0)

    def __call__(self, sample: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        """Extract model inputs from a raw sample.

        Args:
            sample: mapping of modality name -> float array of shape
                [timesteps, channels]. Missing enabled modalities are
                treated as zero-length sequences.

        Returns:
            dict with:
                input_values: float32 array [max_length, total_channels].
                attention_mask: int64 array [max_length]; 1 marks real
                    (non-padding) timesteps.

        Raises:
            ValueError: if a modality array has the wrong shape, or if no
                modality is enabled in the config.
        """
        features: list[np.ndarray] = []
        valid_steps = 0
        for name, spec in self.modalities.items():
            if not spec.get("enabled", False):
                continue
            channels = int(spec["channels"])
            values = np.asarray(sample.get(name, np.zeros((0, channels), dtype=np.float32)))
            if values.ndim != 2 or values.shape[1] != channels:
                raise ValueError(
                    f"Expected modality '{name}' to have shape [timesteps, {channels}], "
                    f"got {values.shape}."
                )
            # Real (pre-padding) length of the longest modality, capped at
            # max_length, determines how many mask positions are valid.
            valid_steps = max(valid_steps, min(values.shape[0], self.max_length))
            normalized = self._normalize(values.astype(np.float32))
            features.append(self._pad_or_trim(normalized))
        if not features:
            raise ValueError("No enabled modalities were provided.")
        stacked = np.concatenate(features, axis=1)
        # Build the mask from the tracked length rather than from values:
        # the old value-based mask (abs-sum > 0) wrongly kept padded rows
        # whenever padding_value != 0, and wrongly dropped genuine
        # timesteps that normalize to all-zeros (e.g. constant channels).
        attention_mask = np.zeros(self.max_length, dtype=np.int64)
        attention_mask[:valid_steps] = 1
        return {
            "input_values": stacked,
            "attention_mask": attention_mask,
        }