#!/usr/bin/env python3
"""
train.py - PPO training loop for the High-Frequency Risk Compliance Auditor.
=============================================================================

Wraps the FinAuditorEnvironment in a Gymnasium-compatible adapter and trains
a PPO agent using Stable Baselines3.

NaN-collapse fixes applied (see inline comments):
    1. Observations are normalized and clipped to ±10 by the VecNormalize
       wrapper in main() instead of reaching the policy unbounded at ±inf.
    2. _process_obs emits a fixed-size, zero-padded float32 vector, so the
       policy never sees ragged input that could trigger gradient explosion.
    3. Episodes terminate via done=True (set in fin_auditor_environment.py)
       so PPO can compute GAE advantages without infinite truncation.
    4. Density-based reward in the environment removes the sparse-penalty
       dead zone.

Usage:
    python train.py
"""
import os
import sys

import numpy as np
import gymnasium as gym
from gymnasium import spaces

# Stable Baselines3 for PPO
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import CheckpointCallback

# Add project root so the hft_auditor .so is importable
_ROOT = os.path.dirname(os.path.abspath(__file__))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

from server.fin_auditor_environment import FinAuditorEnvironment
from models import AuditorAction

# ── Hyperparameters ──────────────────────────────────────────────────────────
N_FEATURES = 4        # [time_elapsed, price_delta, missing_frequency, risk_score]
MAX_TRADES = 40       # maximum anomalies per step (== INGEST_CHUNK_SIZE)
TOTAL_TIMESTEPS = 100_000
SAVE_FREQ = 5_000
LOG_DIR = "./logs/"
SAVE_PATH = os.path.join(LOG_DIR, "rl_model")


class GymnasiumFinAuditorEnv(gym.Env):
    """
    Gymnasium wrapper around FinAuditorEnvironment.

    Observation: flat float32 array of shape (MAX_TRADES * N_FEATURES,).
        Raw values are unbounded here; normalization and clipping are applied
        by the VecNormalize wrapper in main() to prevent NaN gradients.
    Action: MultiDiscrete([2] * MAX_TRADES)
        0=PASS, 1=FLAG per trade slot.
    """

    metadata = {"render_modes": []}

    def __init__(self) -> None:
        super().__init__()
        self._env = FinAuditorEnvironment()
        obs_size = MAX_TRADES * N_FEATURES
        # Normalization and clipping are handled by the VecNormalize wrapper
        # in main(), so the raw space is left unbounded.
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(obs_size,),
            dtype=np.float32,
        )
        # One discrete decision per trade slot
        self.action_space = spaces.MultiDiscrete([2] * MAX_TRADES)

    def _process_obs(self, features: list[list[float]]) -> np.ndarray:
        """Flatten the anomaly matrix into a fixed-size float32 vector."""
        flat = np.zeros(MAX_TRADES * N_FEATURES, dtype=np.float32)
        for i, row in enumerate(features[:MAX_TRADES]):
            for j, val in enumerate(row[:N_FEATURES]):
                flat[i * N_FEATURES + j] = float(val)
        # Unfilled slots stay zero-padded here; normalization happens later
        # in the VecNormalize wrapper.
        return flat

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict | None = None,
    ) -> tuple[np.ndarray, dict]:
        super().reset(seed=seed)
        obs_obj = self._env.reset()
        obs = self._process_obs(obs_obj.features)
        return obs, {}

    def step(
        self, action: np.ndarray
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        decisions = action.tolist()  # MultiDiscrete → Python list of ints
        action_obj = AuditorAction(decisions=decisions)
        obs_obj = self._env.step(action_obj)
        obs = self._process_obs(obs_obj.features)
        reward = float(obs_obj.reward) if obs_obj.reward is not None else 0.0
        done = bool(obs_obj.done)  # True when step_count >= _MAX_EPISODE_STEPS
        return obs, reward, done, False, {}

    def render(self) -> None:
        pass


# ─────────────────────────────────────────────────────────────────────────────
# Training entrypoint
# ─────────────────────────────────────────────────────────────────────────────
def main() -> None:
    os.makedirs(LOG_DIR, exist_ok=True)
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

    env = GymnasiumFinAuditorEnv()

    # Sanity-check the raw environment before vectorization
    print("[TRAIN] Running Gymnasium environment check...")
    check_env(env, warn=True)
    print("[TRAIN] Environment check passed.\n")

    # WRAP: Use DummyVecEnv and VecNormalize for robust training.
    # VecNormalize (like most SB3 wrappers) can only wrap a vectorized env.
    env = DummyVecEnv([lambda: env])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)
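    # clip_obs=10.0 clips each normalized feature to [-10, 10], i.e. roughly
    # ±10 running standard deviations, so a single outlier anomaly cannot
    # blow up the policy or value gradients.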
    checkpoint_callback = CheckpointCallback(
        save_freq=SAVE_FREQ,
        save_path=LOG_DIR,
        name_prefix="rl_model",
        verbose=1,
    )

    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        device="cpu",
        n_steps=2048,          # rollout buffer length per env per update
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.01,         # mild entropy bonus for exploration
        vf_coef=0.5,
        max_grad_norm=0.5,     # gradient clipping prevents NaN proliferation
        tensorboard_log=LOG_DIR,
    )
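    # With a single env, n_steps=2048 and batch_size=64, each update runs
    # n_epochs=10 passes over 2048 / 64 = 32 minibatches, i.e. 320 gradient
    # steps per rollout.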
| print(f"[TRAIN] Starting PPO training for {TOTAL_TIMESTEPS} timesteps...\n") | |
| try: | |
| model.learn( | |
| total_timesteps=TOTAL_TIMESTEPS, | |
| callback=checkpoint_callback, | |
| progress_bar=True, | |
| ) | |
| except KeyboardInterrupt: | |
| print("\n[TRAIN] Training interrupted by user.") | |
    final_path = os.path.join(LOG_DIR, "ppo_fin_auditor_final")
    model.save(final_path)
    # Persist the VecNormalize running statistics alongside the weights;
    # without them a reloaded model would see differently-scaled observations.
    env.save(os.path.join(LOG_DIR, "vec_normalize.pkl"))
    print(f"\n[TRAIN] Model saved to: {final_path}.zip")
if __name__ == "__main__":
    main()