# LovecaSim/ai/training/train_vectorized.py
import os
import sys
import time
# Immediate feedback
print(" [Init] Python process started. Loading libraries...")
print(" [Init] Loading Pytorch...", end="", flush=True)
import torch
import torch as th
import torch.nn.functional as F
print(" Done.")
print(" [Init] Loading Gymnasium & SB3...", end="", flush=True)
import glob
import warnings
import numpy as np
from gymnasium import spaces
from sb3_contrib import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.utils import explained_variance
from tqdm import tqdm
print(" Done.")
# Filter Numba warning
warnings.filterwarnings("ignore", category=RuntimeWarning, message="nopython is set for njit")
# Ensure project root is in path
sys.path.append(os.getcwd())
print(" [Init] Loading LovecaSim Vector Engine...", end="", flush=True)
from ai.environments.vec_env_adapter import VectorEnvAdapter
from ai.utils.loveca_features_extractor import LovecaFeaturesExtractor
print(" Done.")
class TimeCheckpointCallback(BaseCallback):
"""
Save the model every N minutes.
"""
def __init__(self, save_freq_minutes: float, save_path: str, name_prefix: str, verbose: int = 0):
super().__init__(verbose)
self.save_freq_seconds = save_freq_minutes * 60
self.save_path = save_path
self.name_prefix = name_prefix
self.last_time_save = time.time()
def _on_step(self) -> bool:
if (time.time() - self.last_time_save) > self.save_freq_seconds:
save_path = os.path.join(self.save_path, f"{self.name_prefix}_time_auto")
self.model.save(save_path)
if self.verbose > 0:
print(f" [Save] Model auto-saved after 3 minutes to {save_path}")
self.last_time_save = time.time()
return True
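# NOTE: TimeCheckpointCallback is retained for compatibility; main() below uses
# ModelSnapshotCallback (which replaces it) for periodic saves.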
class ModelSnapshotCallback(BaseCallback):
"""
Saves a 'Model Snapshot' every X minutes:
- model.zip
- verified_card_pool.json (Context)
- snapshot_meta.json (Architecture/Config)
"""
def __init__(self, save_freq_minutes: float, save_path: str, verbose=0):
super().__init__(verbose)
self.save_freq_minutes = save_freq_minutes
self.save_path = save_path
self.last_save_time = time.time()
        # Ensure the snapshot directory exists
        os.makedirs(self.save_path, exist_ok=True)
def _on_step(self) -> bool:
if time.time() - self.last_save_time > self.save_freq_minutes * 60:
self.last_save_time = time.time()
self._save_snapshot()
return True
def _save_snapshot(self):
timestamp = time.strftime("%Y%m%d_%H%M%S")
steps = self.num_timesteps
snapshot_name = f"{timestamp}_{steps}_steps"
snapshot_dir = os.path.join("historiccheckpoints", snapshot_name)
if self.verbose > 0:
print(f" [Snapshot] Saving to {snapshot_dir}...")
os.makedirs(snapshot_dir, exist_ok=True)
# 1. Save Model
model_path = os.path.join(snapshot_dir, "model.zip")
self.model.save(model_path)
# 2. Save Card Pool (Context)
try:
import shutil
shutil.copy("verified_card_pool.json", os.path.join(snapshot_dir, "verified_card_pool.json"))
except Exception as e:
print(f" [Snapshot] Warning: Could not copy card pool: {e}")
# 3. Save Metadata (Architecture)
meta = {
"timestamp": timestamp,
"timesteps": int(steps),
"obs_dim": int(self.model.observation_space.shape[0]),
"action_space_size": int(self.model.action_space.n),
"features": ["GlobalVolumes", "LiveZone", "Traits", "TurnNumber"],
"notes": "Generated by ModelSnapshotCallback",
}
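        # obs_dim / action_space_size let a later load verify compatibility with the
        # environment before resuming training (see the dimension check in main()).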
try:
import json
with open(os.path.join(snapshot_dir, "snapshot_meta.json"), "w") as f:
json.dump(meta, f, indent=2)
except Exception as e:
print(f" [Snapshot] Warning: Could not save meta: {e}")
# 4. Limit to Last 5 Snapshots
self._prune_snapshots()
    def _prune_snapshots(self):
        # Keep only the most recent snapshots inside the snapshot directory.
        search_dir = self.save_path
if not os.path.exists(search_dir):
return
# Get list of directories
try:
subdirs = [
os.path.join(search_dir, d)
for d in os.listdir(search_dir)
if os.path.isdir(os.path.join(search_dir, d))
]
# Sort by creation time (oldest first)
subdirs.sort(key=os.path.getctime)
# Keep last 5
max_keep = 5
if len(subdirs) > max_keep:
to_remove = subdirs[:-max_keep]
import shutil
for d in to_remove:
try:
shutil.rmtree(d)
if self.verbose > 0:
print(f" [Snapshot] Pruned old snapshot: {d}")
except Exception as e:
print(f" [Snapshot] Warning: Failed to prune {d}: {e}")
except Exception as e:
print(f" [Snapshot] Warning: Pruning failed: {e}")
class DetailedStatusCallback(BaseCallback):
"""
Logs detailed phase information (Collection vs Optimization) and VRAM usage.
"""
def __init__(self, verbose=0):
super().__init__(verbose)
self.collection_start_time = 0.0
def _on_rollout_start(self) -> None:
"""
A rollout is the collection of environment steps.
"""
self.collection_start_time = time.time()
print(f"\n [Phase] Starting Rollout Collection (Steps: {self.model.n_steps})...")
def _on_rollout_end(self) -> None:
"""
This event is triggered before updating the policy.
"""
duration = time.time() - self.collection_start_time
n_envs = self.model.n_envs
n_steps = self.model.n_steps
total_steps = n_envs * n_steps
fps = total_steps / duration if duration > 0 else 0
print(f" [Phase] Collection Complete. Duration: {duration:.2f}s ({fps:.0f} FPS)")
# PPO optimization is about to start
print(f" [Phase] Starting PPO Optimization (Epochs: {self.model.n_epochs}, Batch: {self.model.batch_size})...")
if torch.cuda.is_available():
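            # memory_allocated() = bytes held by live tensors; memory_reserved() = total
            # bytes kept by PyTorch's caching allocator (always >= allocated).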
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
print(f" [VRAM] Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB")
print(" [Info] Optimization may take time if batch size is large. Please wait...")
def _on_step(self) -> bool:
return True
class TrainingStatsCallback(BaseCallback):
"""
Simple stats logging for Vectorized Training.
"""
def __init__(self, verbose=0):
super().__init__(verbose)
def _on_step(self) -> bool:
# Log win rate if available in infos
infos = self.locals.get("infos")
if infos:
# VectorEnv doesn't emit 'win_rate' in infos by default unless we add it
# But we can look for 'episode' keys
episodes = [i.get("episode") for i in infos if "episode" in i]
if episodes:
rew = np.mean([ep["r"] for ep in episodes])
length = np.mean([ep["l"] for ep in episodes])
self.logger.record("rollout/ep_rew_mean", rew)
self.logger.record("rollout/ep_len_mean", length)
return True
class ProgressMaskablePPO(MaskablePPO):
"""
MaskablePPO with a tqdm progress bar during the optimization phase.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.optimization_pbar = None
def train(self) -> None:
"""
Update policy using the currently gathered rollout buffer.
"""
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# Compute current clip range
clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator]
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator]
entropy_losses = []
pg_losses, value_losses = [], []
clip_fractions = []
continue_training = True
# train for n_epochs epochs
# ADDED: Persistent TQDM Progress Bar
        # Minibatches per epoch = transitions in the buffer (n_steps * n_envs) / batch_size
        total_steps = self.n_epochs * max(1, (self.rollout_buffer.buffer_size * self.rollout_buffer.n_envs) // self.batch_size)
if self.optimization_pbar is None:
self.optimization_pbar = tqdm(total=total_steps, desc="Optimization", unit="batch", leave=True)
else:
self.optimization_pbar.reset(total=total_steps)
for epoch in range(self.n_epochs):
approx_kl_divs = []
# Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(self.batch_size):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()
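                # Re-evaluate the stored observations under the current policy, applying the
                # action masks recorded at collection time (MaskablePPO keeps them in its
                # rollout buffer). The autocast context is a no-op on CPU-only runs.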
with th.cuda.amp.autocast(enabled=th.cuda.is_available()):
values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations,
actions,
action_masks=rollout_data.action_masks,
)
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)
# clipped surrogate loss
policy_loss_1 = advantages * ratio
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
# Logging
pg_losses.append(policy_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
if self.clip_range_vf is None:
# No clipping
values_pred = values
else:
                    # Clip the difference between old and new value
# NOTE: this depends on the reward scaling
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf
)
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(rollout_data.returns, values_pred)
value_losses.append(value_loss.item())
                # Entropy loss favors exploration
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -th.mean(-log_prob)
else:
entropy_loss = -th.mean(entropy)
entropy_losses.append(entropy_loss.item())
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
# Calculate approximate form of reverse KL Divergence for early stopping
# see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
# and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
# and Schulman blog: http://joschu.net/blog/kl-approx.html
with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False
if self.verbose >= 1:
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break
# Optimization step
self.policy.optimizer.zero_grad()
# AMP: Automatic Mixed Precision
# Check if scaler exists (backward compatibility)
if not hasattr(self, "scaler"):
self.scaler = th.cuda.amp.GradScaler(enabled=th.cuda.is_available())
# Backward pass
self.scaler.scale(loss).backward()
# Clip grad norm
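                # Gradients must be unscaled before clipping so that max_grad_norm applies
                # to true gradient magnitudes rather than AMP-scaled ones.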
self.scaler.unscale_(self.policy.optimizer)
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
# Optimizer step
self.scaler.step(self.policy.optimizer)
self.scaler.update()
# Update Progress Bar
self.optimization_pbar.update(1)
if not continue_training:
break
# Don't close, just leave it for the next reset
self._n_updates += self.n_epochs
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
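        # Explained variance of the value function: 1 = perfect return prediction,
        # 0 = no better than predicting the mean return.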
# Logs
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
self.logger.record("train/clip_fraction", np.mean(clip_fractions))
self.logger.record("train/loss", loss.item())
self.logger.record("train/explained_variance", explained_var)
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/clip_range", clip_range)
if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf)
def _excluded_save_params(self) -> list[str]:
"""
Returns the names of the parameters that should be excluded from being saved.
"""
return super()._excluded_save_params() + ["optimization_pbar"]
def main():
print("========================================================")
print(" LovecaSim - STARTING VECTORIZED TRAINING (700k+ SPS) ")
print("========================================================")
# Configuration from Environment Variables
TOTAL_TIMESTEPS = int(os.getenv("TRAIN_STEPS", "100_000_000"))
BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "8192"))
NUM_ENVS = int(os.getenv("TRAIN_ENVS", "4096"))
N_STEPS = int(os.getenv("TRAIN_N_STEPS", "256"))
# Advanced Hyperparameters
ENT_COEF = float(os.getenv("ENT_COEF", "0.01"))
GAMMA = float(os.getenv("GAMMA", "0.99"))
GAE_LAMBDA = float(os.getenv("GAE_LAMBDA", "0.95"))
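    # Example invocation (illustrative; the values below are assumptions, adjust to your hardware):
    #   TRAIN_ENVS=2048 TRAIN_BATCH_SIZE=4096 LOAD_CHECKPOINT=LATEST python ai/training/train_vectorized.py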
SAVE_PATH = "./checkpoints/vector/"
os.makedirs(SAVE_PATH, exist_ok=True)
# Log Hardware/Threading Config
omp_threads = os.getenv("OMP_NUM_THREADS", "Unset (All Cores)")
print(f" [Config] Batch Size: {BATCH_SIZE}")
print(f" [Config] Num Envs: {NUM_ENVS}")
print(f" [Config] N Steps: {N_STEPS}")
print(f" [Config] CPU Cores: {omp_threads}")
# 1. Create Vector Environment (Numba)
print(f" [Init] Creating {NUM_ENVS} parallel Numba environments...")
env = VectorEnvAdapter(num_envs=NUM_ENVS)
# --- WARMUP / COMPILATION ---
print(" [Init] Compiling Numba functions (Reset)... This may take 30s+")
env.reset()
print(" [Init] Compiling Numba functions (Step)... This may take 60s+")
# Perform a dummy step to force compilation of the massive step kernel
dummy_actions = np.zeros(NUM_ENVS, dtype=np.int32)
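    # NOTE (assumption): action index 0 is taken to be legal in a freshly reset
    # environment; substitute a known-valid action if the adapter rejects it.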
env.step(dummy_actions)
print(" [Init] Compilation complete! Starting training...")
# ----------------------------
# 2. Setup or Load PPO Agent
checkpoint_path = os.getenv("LOAD_CHECKPOINT", "")
# Auto-resolve "LATEST" or "AUTO"
force_restart = os.getenv("RESTART_TRAINING", "FALSE").upper() == "TRUE"
if force_restart:
print(" [Config] RESTART_TRAINING=TRUE. Ignoring checkpoints.")
checkpoint_path = ""
elif checkpoint_path.upper() in ["LATEST", "AUTO"]:
list_of_files = glob.glob(os.path.join(SAVE_PATH, "*.zip"))
if list_of_files:
checkpoint_path = max(list_of_files, key=os.path.getctime)
print(f" [Config] LOAD_CHECKPOINT='{os.getenv('LOAD_CHECKPOINT')}' -> Auto-resolved to: {checkpoint_path}")
else:
print(" [Config] LOAD_CHECKPOINT='LATEST' but no checkpoints found. Starting fresh.")
checkpoint_path = ""
model = None
if checkpoint_path and os.path.exists(checkpoint_path):
print(f" [Load] Scanning checkpoint: {checkpoint_path}")
try:
            # Load on CPU first to check observation dimensions before a full (GPU) load
temp_model = ProgressMaskablePPO.load(checkpoint_path, device="cpu")
model_obs_dim = temp_model.observation_space.shape[0]
env_obs_dim = env.observation_space.shape[0]
if model_obs_dim != env_obs_dim:
print(f" [Load] Dimension Mismatch! Model: {model_obs_dim}, Env: {env_obs_dim}")
print(f" [Load] Cannot resume training across eras. Starting FRESH {env_obs_dim}-dim model.")
model = None
else:
print(f" [Load] Dimensions match ({model_obs_dim}). Resuming training...")
model = ProgressMaskablePPO.load(
checkpoint_path,
env=env,
device="cuda" if torch.cuda.is_available() else "cpu",
custom_objects={
"learning_rate": float(os.getenv("LEARNING_RATE", "3e-4")),
"batch_size": BATCH_SIZE,
"n_epochs": int(os.getenv("NUM_EPOCHS", "4")),
},
)
reset_num_timesteps = False
print(" [Load] Success.")
except Exception as e:
print(f" [Error] Failed to load checkpoint: {e}")
print(" [Init] Falling back to fresh model...")
model = None
if model is None:
if checkpoint_path and not os.path.exists(checkpoint_path):
print(f" [Warning] Checkpoint file not found: {checkpoint_path}")
print(" [Init] Creating fresh ProgressMaskablePPO model...")
# Determine Policy Args
obs_mode_env = os.getenv("OBS_MODE", "STANDARD")
if obs_mode_env == "ATTENTION":
print(" [Init] Using LovecaFeaturesExtractor (Attention)")
policy_kwargs = dict(
features_extractor_class=LovecaFeaturesExtractor,
features_extractor_kwargs=dict(features_dim=256),
net_arch=[],
)
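        # net_arch=[] means no extra MLP layers: the 256-dim extractor output feeds
        # the policy and value heads directly.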
else:
policy_kwargs = dict(net_arch=[512, 512])
model = ProgressMaskablePPO(
"MlpPolicy",
env,
verbose=1,
learning_rate=float(os.getenv("LEARNING_RATE", "3e-4")),
n_steps=N_STEPS,
batch_size=BATCH_SIZE,
n_epochs=int(os.getenv("NUM_EPOCHS", "4")),
gamma=GAMMA,
gae_lambda=GAE_LAMBDA,
ent_coef=ENT_COEF,
tensorboard_log="./logs/vector_tensorboard/",
policy_kwargs=policy_kwargs,
)
reset_num_timesteps = True
print(f" [Init] PPO Model initialized. Device: {model.device}")
# 3. Callbacks
# Refactored: Callbacks moved to module level.
# Standard Checkpoint (Keep for compatibility/safety)
checkpoint_callback = CheckpointCallback(
save_freq=max(1, 1000000 // NUM_ENVS), save_path=SAVE_PATH, name_prefix="numba_ppo"
)
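    # CheckpointCallback counts calls to _on_step (one per vectorized env.step, i.e. per
    # NUM_ENVS timesteps), so this saves roughly every 1M total environment timesteps.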
save_freq = float(os.getenv("SAVE_FREQ_MINS", "15.0"))
# Snapshot Callback (Replaces TimeCheckpointCallback)
snapshot_callback = ModelSnapshotCallback(
save_freq_minutes=save_freq,
save_path="historiccheckpoints",
verbose=1,
)
    # OBS_MODE is not currently written into snapshot_meta.json; the environment is
    # assumed to track it. Extend ModelSnapshotCallback if it needs to be recorded.
# 4. Train
print(" [Train] Starting training loop...")
print(f" [Train] Model Mode: {os.getenv('OBS_MODE', 'STANDARD')}")
print(f" [Train] Reset Timesteps: {reset_num_timesteps}")
print(" [Note] Press Ctrl+C to stop and force-save.")
# Generate a timestamped run name for TensorBoard
run_name = f"ProgressPPO_{time.strftime('%m%d_%H%M%S')}"
if not reset_num_timesteps:
run_name += "_RESUME"
try:
model.learn(
total_timesteps=TOTAL_TIMESTEPS,
callback=[
checkpoint_callback,
snapshot_callback,
TrainingStatsCallback(),
DetailedStatusCallback(),
], # Use Snapshot + DetailedStatus!
progress_bar=True,
reset_num_timesteps=reset_num_timesteps,
tb_log_name=run_name,
)
print(" [Train] Training finished.")
model.save(f"{SAVE_PATH}/final_model")
except KeyboardInterrupt:
print("\n [Train] Interrupted by user. Saving model...")
model.save(f"{SAVE_PATH}/interrupted_model")
# Trigger explicit snapshot on interrupt
snapshot_callback._save_snapshot()
print(" [Train] Model saved.")
except Exception as e:
print(f"\n [Error] Training crashed: {e}")
import traceback
traceback.print_exc()
emergency_save = os.path.join(SAVE_PATH, "crash_emergency")
model.save(emergency_save)
print(f" [Save] Crash emergency checkpoint saved to: {emergency_save}")
finally:
print(" [Done] Exiting gracefully.")
env.close()
if __name__ == "__main__":
main()