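"""Optimized training entry point for LovecaSim.

Trains a MaskablePPO agent on LoveLiveCardGameEnv across batched subprocess workers,
with callbacks for stats logging, best-win-rate checkpointing, and periodic self-play
opponent snapshots.
"""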
import os
import sys
import numpy as np
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
# Ensure project root is in path for local imports
if os.getcwd() not in sys.path:
sys.path.append(os.getcwd())
print(" [Heartbeat] train_optimized.py entry point reached.", flush=True)
import argparse
import torch
# Keep intra-op torch threads low: with few workers (TRAIN_CPUS <= 4) allow 2 threads each, otherwise 1 per process to avoid CPU oversubscription
if int(os.getenv("TRAIN_CPUS", "4")) <= 4:
torch.set_num_threads(2)
else:
torch.set_num_threads(1)
# Fix for Windows DLL loading issues in subprocesses
if sys.platform == "win32":
# Add torch lib to DLL search path
torch_lib_path = os.path.join(os.path.dirname(torch.__file__), "lib")
if os.path.exists(torch_lib_path):
os.add_dll_directory(torch_lib_path)
# Ensure CUDA_PATH is in environ if found
if "CUDA_PATH" not in os.environ:
cuda_path = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.2"
if os.path.exists(cuda_path):
os.environ["CUDA_PATH"] = cuda_path
os.environ["PATH"] = os.path.join(cuda_path, "bin") + os.pathsep + os.environ["PATH"]
# Import our environment
from functools import partial
from ai.batched_env import BatchedSubprocVecEnv
from ai.gym_env import LoveLiveCardGameEnv
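# Project-local pieces: LoveLiveCardGameEnv is the per-worker game environment, and
# BatchedSubprocVecEnv runs the opponent model centrally so worker processes stay lightweight.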
class TrainingStatsCallback(BaseCallback):
"""Custom callback for logging win rates and illegal move stats from gym_env."""
def __init__(self, verbose=0):
        super().__init__(verbose)
def _on_step(self) -> bool:
if self.n_calls == 1:
print(" [Heartbeat] Training loop is active. First step reached!", flush=True)
infos = self.locals.get("infos")
if infos:
# 1. Capture Win Rate from custom env attribute (Legacy/Direct)
if len(infos) > 0 and "win_rate" in infos[0]:
avg_win_rate = np.mean([info.get("win_rate", 0) for info in infos])
self.logger.record("game/win_rate_legacy", avg_win_rate)
# 1b. Per-game Heartbeat
for info in infos:
if "episode" in info:
print(
f" [Heartbeat] Game completed! Reward: {info['episode']['r']:.2f} | Turns: {info['episode']['l']}",
flush=True,
)
# 2. Capture Episode Completion Stats
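            # "turn" and "win" are custom keys assumed to be present in the Monitor episode
            # record alongside the standard "r" (reward) and "l" (length).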
episode_infos = [info.get("episode") for info in infos if "episode" in info]
if episode_infos:
avg_reward = np.mean([ep["r"] for ep in episode_infos])
avg_turns = np.mean([ep["turn"] for ep in episode_infos])
win_count = sum(1 for ep in episode_infos if ep["win"])
win_rate = (win_count / len(episode_infos)) * 100
self.logger.record("game/avg_episode_reward", avg_reward)
self.logger.record("game/avg_win_turn", avg_turns)
self.logger.record("game/win_rate_rolling", win_rate)
# Periodic summary to terminal (More frequent for visibility)
if self.n_calls % 256 == 0:
print(
f" [Stats] Steps: {self.num_timesteps} | Win Rate: {win_rate:.1f}% | Avg Reward: {avg_reward:.2f} | Avg Turn: {avg_turns:.1f}",
flush=True,
)
return True
class SaveOnBestWinRateCallback(BaseCallback):
"""Callback to save the model when win rate reaches a new peak."""
def __init__(self, check_freq: int, save_path: str, verbose=1):
        super().__init__(verbose)
self.check_freq = check_freq
self.save_path = save_path
self.best_win_rate = -np.inf
self.min_win_rate_threshold = 30.0 # Only save 'best' if above 30% to avoid early noise
def _init_callback(self) -> None:
if self.save_path is not None:
os.makedirs(self.save_path, exist_ok=True)
def _on_step(self) -> bool:
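        # "win_rate" is the rolling win-rate statistic the env is expected to report in
        # each worker's info dict (the same source TrainingStatsCallback reads above).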
if self.n_calls % self.check_freq == 0:
infos = self.locals.get("infos")
if infos:
avg_win_rate = np.mean([info.get("win_rate", 0) for info in infos])
if avg_win_rate > self.best_win_rate and avg_win_rate > self.min_win_rate_threshold:
self.best_win_rate = avg_win_rate
if self.verbose > 0:
print(
f" [Saving] New Best Win Rate: {avg_win_rate:.1f}%! Progressing towards big moment...",
flush=True,
)
self.model.save(os.path.join(self.save_path, "best_win_rate_model"))
return True
class SelfPlayUpdateCallback(BaseCallback):
"""Callback to save the model for self-play opponents."""
def __init__(self, update_freq: int, save_path: str, verbose=0):
        super().__init__(verbose)
self.update_freq = update_freq
self.save_path = save_path
def _init_callback(self) -> None:
if self.save_path is not None:
os.makedirs(self.save_path, exist_ok=True)
def _on_step(self) -> bool:
if self.n_calls % self.update_freq == 0:
if self.verbose > 0:
print(" [Self-Play] Updating opponent model...", flush=True)
self.model.save(os.path.join(self.save_path, "self_play_opponent"))
return True
def create_env(rank, usage=0.5, deck_type="random_verified", opponent_type="random"):
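    """Build one worker environment.

    Wraps LoveLiveCardGameEnv in Monitor (episode stats) and ActionMasker (legal-action
    masks), then seeds it per worker rank for rollout diversity.
    """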
env = LoveLiveCardGameEnv(target_cpu_usage=usage, deck_type=deck_type, opponent_type=opponent_type)
env = Monitor(env)
env = ActionMasker(env, lambda e: e.unwrapped.action_masks())
# Seed for diversity across workers
env.reset(seed=42 + rank)
return env
def train():
# 1. Hardware Constraints Setup
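    # All training knobs are read from environment variables: worker count, per-core
    # throttle, deck/opponent selection, GPU memory cap, and PPO rollout/batch/epoch sizes.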
num_cpu = int(os.getenv("TRAIN_CPUS", "4"))
usage = float(os.getenv("TRAIN_USAGE", "0.5"))
deck_type = os.getenv("TRAIN_DECK", "random_verified")
gpu_usage = float(os.getenv("TRAIN_GPU_USAGE", "0.7"))
batch_size = int(os.getenv("TRAIN_BATCH_SIZE", "256"))
n_epochs = int(os.getenv("TRAIN_EPOCHS", "10"))
n_steps = int(os.getenv("TRAIN_STEPS", "2048"))
opponent_type = os.getenv("TRAIN_OPPONENT", "random")
if torch.cuda.is_available() and gpu_usage < 1.0:
try:
print(f"Limiting GPU memory usage to {int(gpu_usage * 100)}%...", flush=True)
torch.cuda.set_per_process_memory_fraction(gpu_usage)
except Exception as e:
print(f"Warning: Could not set GPU memory fraction: {e}. Proceeding without limit.", flush=True)
print(
f"Initializing {num_cpu} parallel environments ({deck_type}) with opponent {opponent_type} and {int(usage * 100)}% per-core throttle...",
flush=True,
)
# Create Vectorized Environment
try:
# Optimization: Workers always use "random" internally because BatchedSubprocVecEnv intercepts and runs the real opponent
# This prevents workers from importing torch/sb3 and saves GBs of RAM.
env_fns = [
partial(create_env, rank=i, usage=usage, deck_type=deck_type, opponent_type="random")
for i in range(num_cpu)
]
# Use our new Batched inference environment
opponent_path = os.path.join(os.getcwd(), "checkpoints", "self_play_opponent.zip")
env = BatchedSubprocVecEnv(env_fns, opponent_model_path=opponent_path if opponent_type == "self_play" else None)
print("Batched workers initialized! Starting training loop...", flush=True)
except Exception as e:
print(f"CRITICAL ERROR during worker initialization: {e}", flush=True)
import traceback
traceback.print_exc()
return
# 2. Model Configuration
load_path = os.getenv("LOAD_MODEL")
model = None
if load_path and os.path.exists(load_path):
try:
print(f" [LOAD] Loading existing model from {load_path}...", flush=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MaskablePPO.load(load_path, env=env, device=device)
print(" [LOAD] Model loaded successfully.", flush=True)
except ValueError as val_err:
if "Observation spaces do not match" in str(val_err):
print(
f" [WARNING] Checkpoint {load_path} has incompatible observation space (likely from an older engine version).",
flush=True,
)
print(" [WARNING] Skipping load and starting fresh to maintain stability.", flush=True)
model = None # Force fresh start
else:
raise val_err
except Exception as load_err:
print(f" [CRITICAL ERROR] Failed to load checkpoint: {load_err}", flush=True)
import traceback
traceback.print_exc()
env.close()
sys.exit(1)
if model is None:
print(" [INFO] Initializing fresh MaskablePPO model...", flush=True)
model = MaskablePPO(
"MlpPolicy",
env,
verbose=0,
gamma=0.99,
learning_rate=3e-4,
n_steps=n_steps,
batch_size=batch_size,
n_epochs=n_epochs,
tensorboard_log="./logs/ppo_tensorboard/",
device="cuda",
)
# NEW: Dry run support
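    # parse_known_args is used so unrecognized CLI arguments are ignored rather than raising.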
parser = argparse.ArgumentParser()
parser.add_argument("--dry-run", action="store_true", help="Initialize and exit")
args, unknown = parser.parse_known_args()
if args.dry_run:
print(" [Dry Run] Workers initialized successfully. Exiting.", flush=True)
env.close()
return
print(f"Starting Training on {num_cpu} workers (Usage: {usage * 100}%)...", flush=True)
# Checkpoint Callback
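    # save_freq is measured in callback steps (one per vectorized env step = num_cpu timesteps),
    # so 200000 // num_cpu checkpoints roughly every 200k total timesteps.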
checkpoint_callback = CheckpointCallback(
save_freq=max(1, 200000 // num_cpu), save_path="./checkpoints/", name_prefix="lovelive_ppo_checkpoint"
)
# 3. Learning Loop
stats_callback = TrainingStatsCallback()
best_rate_callback = SaveOnBestWinRateCallback(check_freq=1024, save_path="./checkpoints/")
self_play_callback = SelfPlayUpdateCallback(update_freq=20000, save_path="./checkpoints/")
callback_list = CallbackList([checkpoint_callback, stats_callback, best_rate_callback, self_play_callback])
try:
print(f"Starting Long-Running Training on {num_cpu} workers (Usage: {usage * 100}%)...")
model.learn(total_timesteps=2_000_000_000, progress_bar=False, callback=callback_list)
# Save Final Model
os.makedirs("checkpoints", exist_ok=True)
model.save("checkpoints/lovelive_ppo_optimized")
print("Training Complete. Model Saved.")
except KeyboardInterrupt:
print("\nTraining interrupted. Saving current progress...")
model.save("checkpoints/lovelive_ppo_interrupted")
finally:
env.close()
if __name__ == "__main__":
train()