# forgeenv-source / forgeenv/drift/library_drift_engine.py
# akhiilll's picture
# forgeenv source snapshot for training job
# a15535e verified
"""Library Drift Engine.
Manages library version snapshots and triggers version upgrades during
training to create non-stationary verification. In simulation mode it
just tracks the current snapshot index — that index influences
breakage selection and is exposed in observations so the Repair Agent
can adapt.
Also exposes Chojecki GVU's SNR computation
(https://arxiv.org/abs/2512.02731 Definition 4.4).
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
# Built-in sequence of library-version snapshots. Each entry maps a package
# name to a version string, and entries are listed from the oldest versions
# to the newest. LibraryDriftEngine uses this list as the default value of
# its ``snapshots`` field and starts at index 0.
DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [
{"transformers": "4.36.0", "datasets": "2.14.0", "trl": "0.7.0"},
{"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.8.0"},
{"transformers": "4.45.0", "datasets": "3.0.0", "trl": "0.10.0"},
{"transformers": "4.50.0", "datasets": "3.2.0", "trl": "0.12.0"},
]
@dataclass
class LibraryDriftEngine:
    """Track the active library-version snapshot and advance it over time.

    The engine holds an ordered list of version snapshots (package name ->
    version string) plus an index into that list. ``maybe_drift`` moves the
    index forward on a fixed episode cadence and records every transition
    in ``drift_history``.

    Attributes:
        snapshots: Ordered version snapshots. Defaults to a per-instance
            copy of ``DEFAULT_VERSION_SNAPSHOTS``.
        current_index: Position of the active snapshot in ``snapshots``.
        drift_history: One record per drift that has occurred, each a dict
            with the keys ``"episode"``, ``"from"`` and ``"to"``.
    """

    # Copy the inner dicts as well as the list: a shallow ``list(...)`` would
    # leave every default-constructed engine sharing (and able to corrupt)
    # the module-level snapshot dicts.
    snapshots: list[dict[str, str]] = field(
        default_factory=lambda: [dict(snap) for snap in DEFAULT_VERSION_SNAPSHOTS]
    )
    current_index: int = 0
    drift_history: list[dict] = field(default_factory=list)

    def current_versions(self) -> dict[str, str]:
        """Return a copy of the active snapshot.

        The copy means callers can modify the result without affecting the
        engine's stored snapshots.
        """
        return dict(self.snapshots[self.current_index])

    def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool:
        """Advance to the next snapshot if this episode is a drift point.

        A drift happens when ``episode_num`` is positive, is a multiple of
        ``drift_every``, and a later snapshot exists. Each drift appends a
        record to ``drift_history``.

        Args:
            episode_num: Index of the current training episode.
            drift_every: Number of episodes between drifts. ``0`` disables
                drifting.

        Returns:
            True if the engine advanced to the next snapshot, else False.
        """
        if drift_every == 0:
            # A zero cadence would make the modulo below raise
            # ZeroDivisionError; treat it as "drift disabled" instead.
            return False
        if episode_num <= 0 or episode_num % drift_every != 0:
            return False
        if self.current_index >= len(self.snapshots) - 1:
            # Already on the last snapshot; nothing left to drift to.
            return False

        previous = self.snapshots[self.current_index]
        self.current_index += 1
        # Store copies so that editing a history record can never rewrite the
        # live snapshots (and vice versa), matching ``current_versions``.
        self.drift_history.append(
            {
                "episode": episode_num,
                "from": dict(previous),
                "to": dict(self.snapshots[self.current_index]),
            }
        )
        return True

    def reset(self) -> None:
        """Return to the first snapshot and clear the drift history in place."""
        self.current_index = 0
        self.drift_history.clear()

    @staticmethod
    def compute_snr(
        recent_held_out: list[float], recent_visible: list[float]
    ) -> dict[str, float]:
        """SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards).

        Args:
            recent_held_out: Reward values used for the ``"snr_verifier"``
                entry.
            recent_visible: Reward values used for the ``"snr_generator"``
                entry.

        Returns:
            A dict with the keys ``"snr_verifier"`` and ``"snr_generator"``.
            A series with fewer than two values yields ``0.0``.
        """

        def snr(values: list[float]) -> float:
            if len(values) < 2:
                return 0.0
            mean = sum(values) / len(values)
            # Population variance (divides by n, not n - 1).
            var = sum((v - mean) ** 2 for v in values) / len(values)
            # Floor the variance so a constant series does not divide by zero.
            return mean**2 / max(var, 1e-8)

        return {
            "snr_verifier": snr(recent_held_out),
            "snr_generator": snr(recent_visible),
        }