"""Library Drift Engine.

Manages library version snapshots and triggers version upgrades during
training to create non-stationary verification. In simulation mode it
just tracks the current snapshot index — that index influences
breakage selection and is exposed in observations so the Repair Agent
can adapt.

Also exposes Chojecki GVU's SNR computation
(https://arxiv.org/abs/2512.02731 Definition 4.4).
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field
|
|
# Ordered library-version snapshots used as the default drift schedule.
# Each entry maps a library name to a version string; versions increase from
# one entry to the next, and LibraryDriftEngine only ever steps forward
# through this list, one entry per drift event.
DEFAULT_VERSION_SNAPSHOTS: list[dict[str, str]] = [
    {"transformers": "4.36.0", "datasets": "2.14.0", "trl": "0.7.0"},
    {"transformers": "4.40.0", "datasets": "2.18.0", "trl": "0.8.0"},
    {"transformers": "4.45.0", "datasets": "3.0.0", "trl": "0.10.0"},
    {"transformers": "4.50.0", "datasets": "3.2.0", "trl": "0.12.0"},
]
|
|
|
|
@dataclass
class LibraryDriftEngine:
    """Track which library-version snapshot is active and advance it over time.

    Attributes:
        snapshots: Ordered list of ``{library: version}`` mappings. Defaults
            to an independent copy of ``DEFAULT_VERSION_SNAPSHOTS``.
        current_index: Index into ``snapshots`` of the active snapshot.
        drift_history: One record per drift event, each a dict with keys
            ``"episode"``, ``"from"`` and ``"to"``.
    """

    snapshots: list[dict[str, str]] = field(
        # Copy every inner dict as well as the list, so mutating one engine's
        # snapshots cannot leak into the module-level default or into other
        # engine instances.
        default_factory=lambda: [dict(snap) for snap in DEFAULT_VERSION_SNAPSHOTS]
    )
    current_index: int = 0
    drift_history: list[dict] = field(default_factory=list)

    def current_versions(self) -> dict[str, str]:
        """Return a copy of the active snapshot.

        Raises:
            IndexError: If ``current_index`` is outside ``snapshots``
                (for example when ``snapshots`` is empty).
        """
        return dict(self.snapshots[self.current_index])

    def maybe_drift(self, episode_num: int, drift_every: int = 50) -> bool:
        """Advance to the next snapshot when ``episode_num`` hits a drift boundary.

        A drift happens only when ``episode_num`` is a positive multiple of
        ``drift_every`` and a later snapshot exists. Each drift appends a
        record to ``drift_history``.

        Args:
            episode_num: Current episode counter.
            drift_every: Positive number of episodes between drifts.

        Returns:
            ``True`` if the engine moved to the next snapshot, else ``False``.

        Raises:
            ValueError: If ``drift_every`` is not positive.
        """
        if drift_every <= 0:
            # A zero interval would otherwise surface as a bare
            # ZeroDivisionError from the modulo below.
            raise ValueError(f"drift_every must be positive, got {drift_every}")

        on_boundary = episode_num > 0 and episode_num % drift_every == 0
        has_next = self.current_index < len(self.snapshots) - 1
        if not (on_boundary and has_next):
            return False

        previous = self.snapshots[self.current_index]
        self.current_index += 1
        self.drift_history.append(
            {
                "episode": episode_num,
                # Store copies so later edits to ``snapshots`` cannot rewrite
                # what the history says was active at the time.
                "from": dict(previous),
                "to": dict(self.snapshots[self.current_index]),
            }
        )
        return True

    def reset(self) -> None:
        """Return to the first snapshot and clear ``drift_history`` in place."""
        self.current_index = 0
        self.drift_history.clear()

    @staticmethod
    def compute_snr(
        recent_held_out: list[float], recent_visible: list[float]
    ) -> dict[str, float]:
        """SNR per Chojecki GVU Def 4.4: SNR = mean(rewards)^2 / variance(rewards).

        Args:
            recent_held_out: Reward samples reported as ``"snr_verifier"``.
            recent_visible: Reward samples reported as ``"snr_generator"``.

        Returns:
            Dict with keys ``"snr_verifier"`` and ``"snr_generator"``. A series
            with fewer than two samples yields ``0.0``.
        """

        def snr(values: list[float]) -> float:
            count = len(values)
            if count < 2:
                return 0.0
            mean = sum(values) / count
            # Population variance (divide by N, not N - 1).
            var = sum((v - mean) ** 2 for v in values) / count
            # Floor the variance so a constant series does not divide by zero.
            return mean**2 / max(var, 1e-8)

        return {
            "snr_verifier": snr(recent_held_out),
            "snr_generator": snr(recent_visible),
        }
|
|