"""
KernelX Intelligence Layer — RL Environment (OpenEnv structure)
Provides reset/step interface for training the Strategist policy via GRPO.
Replays recorded transitions from the preprocessed JSONL dataset and
computes multi-objective rewards.
"""
import json
import random
from dataclasses import dataclass
from typing import List, Tuple

from .rewards import RewardComputer


@dataclass
class KernelState:
"""Observation wrapper for the RL environment."""
features: List[float] # active features (10D after preprocessing)
pid: int
cpu: int
timestep: int
    prev_action: float


@dataclass
class KernelAction:
"""Action output from the Strategist."""
    value: float  # scheduling weight in [-1.0, 1.0]


class KernelSchedulerEnv:
    """Offline RL environment that replays recorded kernel transitions.

    Each episode starts at a random position in the dataset and runs for
    max_steps transitions. The reward is computed by the multi-objective
    RewardComputer.
    """

    def __init__(
self,
data_path: str = "training/data/train.jsonl",
max_steps: int = 10,
alpha: float = 1.0,
beta: float = 2.0,
gamma: float = 0.5,
):
        with open(data_path) as f:
            self.records = [json.loads(line) for line in f if line.strip()]
self.max_steps = max_steps
self.reward_computer = RewardComputer(alpha=alpha, beta=beta, gamma=gamma)
# Episode state
self.timestep = 0
self.current_idx = 0
self.prev_action = 0.0
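        # Each episode reads max_steps transitions plus the observation that
        # follows them, so the dataset needs at least max_steps + 1 records.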
if len(self.records) < max_steps + 1:
raise ValueError(
f"Dataset has {len(self.records)} records but max_steps={max_steps} "
f"requires at least {max_steps + 1}"
            )

    def reset(self) -> KernelState:
"""Start a fresh episode from a random point in the dataset."""
self.timestep = 0
self.current_idx = random.randint(0, len(self.records) - self.max_steps - 1)
self.prev_action = 0.0
        return self._get_state()

    def step(self, action: KernelAction) -> Tuple[KernelState, dict, bool]:
        """Apply action, compute reward, advance to next state.

        Returns:
            next_state: The new KernelState after the transition
            reward_breakdown: Dict with 'total' and per-component rewards
            done: Whether the episode has ended
        """
current = self.records[self.current_idx + self.timestep]
next_idx = self.current_idx + self.timestep + 1
next_rec = self.records[next_idx] if next_idx < len(self.records) else current
reward_breakdown = self.reward_computer.compute_total(
state=current["state"],
action=action,
prev_action=self.prev_action,
next_state=next_rec["state"],
)
self.timestep += 1
self.prev_action = action.value
done = self.timestep >= self.max_steps
        return self._get_state(), reward_breakdown, done

    def _get_state(self) -> KernelState:
"""Read the current state from the dataset."""
rec = self.records[self.current_idx + self.timestep]
return KernelState(
features=rec["state"],
pid=rec["pid"],
cpu=rec["cpu"],
timestep=self.timestep,
prev_action=self.prev_action,
        )

    def simulate(self, state_features: list, action_value: float) -> list:
        """Lightweight next-state lookup for reward_fn during GRPO.

        Finds the nearest recorded state in the dataset and returns its
        recorded next_state. This is a simple approximation; the World Model
        provides higher-fidelity simulation.
        """
import numpy as np
state_arr = np.array(state_features)
best_dist = float("inf")
best_next = state_features # fallback
# Sample a subset to keep this fast
sample_size = min(500, len(self.records))
indices = random.sample(range(len(self.records)), sample_size)
for idx in indices:
rec = self.records[idx]
dist = float(np.linalg.norm(state_arr - np.array(rec["state"])))
if dist < best_dist:
best_dist = dist
best_next = rec["next_state"]
return best_next
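

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training pipeline):
# roll out one episode with random actions and query simulate() the way a
# GRPO reward_fn might. Assumes the preprocessed JSONL dataset exists at the
# default path and that the reward breakdown contains a "total" key, as
# described in step()'s docstring. Run as a module (e.g. `python -m ...`) so
# the relative import of RewardComputer resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    env = KernelSchedulerEnv()
    state = env.reset()
    done = False
    episode_return = 0.0
    while not done:
        # A trained Strategist policy would produce this value; a random
        # scheduling weight in [-1.0, 1.0] stands in for it here.
        action = KernelAction(value=random.uniform(-1.0, 1.0))
        state, reward_breakdown, done = env.step(action)
        episode_return += reward_breakdown["total"]
    print(f"episode return over {env.max_steps} steps: {episode_return:.4f}")

    # Nearest-neighbour next-state lookup of the kind reward_fn uses in GRPO.
    approx_next = env.simulate(state.features, action_value=0.0)
    print(f"approximate next state has {len(approx_next)} features")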