import os
import sys

import numpy as np
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
from stable_baselines3.common.monitor import Monitor

# Ensure project root is in path for local imports
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

print(" [Heartbeat] train_optimized.py entry point reached.", flush=True)

import argparse

import torch

# Keep intra-op threads low to avoid CPU oversubscription:
# allow 2 threads when few workers are used, 1 when many.
if int(os.getenv("TRAIN_CPUS", "4")) <= 4:
    torch.set_num_threads(2)
else:
    torch.set_num_threads(1)

# Fix for Windows DLL loading issues in subprocesses
if sys.platform == "win32":
    # Add torch lib to DLL search path
    torch_lib_path = os.path.join(os.path.dirname(torch.__file__), "lib")
    if os.path.exists(torch_lib_path):
        os.add_dll_directory(torch_lib_path)
    # Ensure CUDA_PATH is in environ if found
    if "CUDA_PATH" not in os.environ:
        cuda_path = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.2"
        if os.path.exists(cuda_path):
            os.environ["CUDA_PATH"] = cuda_path
            os.environ["PATH"] = os.path.join(cuda_path, "bin") + os.pathsep + os.environ["PATH"]

# Import our environment
from functools import partial

from ai.batched_env import BatchedSubprocVecEnv
from ai.gym_env import LoveLiveCardGameEnv


class TrainingStatsCallback(BaseCallback):
    """Custom callback for logging win rates and illegal move stats from gym_env."""

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        if self.n_calls == 1:
            print(" [Heartbeat] Training loop is active. First step reached!", flush=True)

        infos = self.locals.get("infos")
        if infos:
            # 1. Capture win rate from custom env attribute (legacy/direct)
            if len(infos) > 0 and "win_rate" in infos[0]:
                avg_win_rate = np.mean([info.get("win_rate", 0) for info in infos])
                self.logger.record("game/win_rate_legacy", avg_win_rate)

            # 1b. Per-game heartbeat
            for info in infos:
                if "episode" in info:
                    print(
                        f" [Heartbeat] Game completed! Reward: {info['episode']['r']:.2f} | Turns: {info['episode']['l']}",
                        flush=True,
                    )

            # 2. Capture episode completion stats.
            # "turn" and "win" are custom keys the environment is expected to add
            # to the episode info alongside Monitor's standard "r" and "l".
            episode_infos = [info.get("episode") for info in infos if "episode" in info]
            if episode_infos:
                avg_reward = np.mean([ep["r"] for ep in episode_infos])
                avg_turns = np.mean([ep["turn"] for ep in episode_infos])
                win_count = sum(1 for ep in episode_infos if ep["win"])
                win_rate = (win_count / len(episode_infos)) * 100
                self.logger.record("game/avg_episode_reward", avg_reward)
                self.logger.record("game/avg_win_turn", avg_turns)
                self.logger.record("game/win_rate_rolling", win_rate)

                # Periodic summary to terminal (more frequent for visibility)
                if self.n_calls % 256 == 0:
                    print(
                        f" [Stats] Steps: {self.num_timesteps} | Win Rate: {win_rate:.1f}% | Avg Reward: {avg_reward:.2f} | Avg Turn: {avg_turns:.1f}",
                        flush=True,
                    )
        return True


class SaveOnBestWinRateCallback(BaseCallback):
    """Callback to save the model when win rate reaches a new peak."""

    def __init__(self, check_freq: int, save_path: str, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
        self.best_win_rate = -np.inf
        self.min_win_rate_threshold = 30.0  # Only save 'best' if above 30% to avoid early noise

    def _init_callback(self) -> None:
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            infos = self.locals.get("infos")
            if infos:
                avg_win_rate = np.mean([info.get("win_rate", 0) for info in infos])
                if avg_win_rate > self.best_win_rate and avg_win_rate > self.min_win_rate_threshold:
                    self.best_win_rate = avg_win_rate
                    if self.verbose > 0:
                        print(
                            f" [Saving] New best win rate: {avg_win_rate:.1f}%! Saving best model...",
                            flush=True,
                        )
                    self.model.save(os.path.join(self.save_path, "best_win_rate_model"))
        return True


class SelfPlayUpdateCallback(BaseCallback):
    """Callback to save the model for self-play opponents."""

    def __init__(self, update_freq: int, save_path: str, verbose=0):
        super().__init__(verbose)
        self.update_freq = update_freq
        self.save_path = save_path

    def _init_callback(self) -> None:
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.update_freq == 0:
            if self.verbose > 0:
                print(" [Self-Play] Updating opponent model...", flush=True)
            self.model.save(os.path.join(self.save_path, "self_play_opponent"))
        return True


def create_env(rank, usage=0.5, deck_type="random_verified", opponent_type="random"):
    env = LoveLiveCardGameEnv(target_cpu_usage=usage, deck_type=deck_type, opponent_type=opponent_type)
    env = Monitor(env)
    # ActionMasker calls this function each step; the env's action_masks() must
    # return a boolean mask over the discrete action space for MaskablePPO.
    env = ActionMasker(env, lambda e: e.unwrapped.action_masks())
    # Seed for diversity across workers
    env.reset(seed=42 + rank)
    return env


def train():
    # 1. Hardware constraints setup
    num_cpu = int(os.getenv("TRAIN_CPUS", "4"))
    usage = float(os.getenv("TRAIN_USAGE", "0.5"))
    deck_type = os.getenv("TRAIN_DECK", "random_verified")
    gpu_usage = float(os.getenv("TRAIN_GPU_USAGE", "0.7"))
    batch_size = int(os.getenv("TRAIN_BATCH_SIZE", "256"))
    n_epochs = int(os.getenv("TRAIN_EPOCHS", "10"))
    n_steps = int(os.getenv("TRAIN_STEPS", "2048"))
    opponent_type = os.getenv("TRAIN_OPPONENT", "random")

    if torch.cuda.is_available() and gpu_usage < 1.0:
        try:
            print(f"Limiting GPU memory usage to {int(gpu_usage * 100)}%...", flush=True)
            torch.cuda.set_per_process_memory_fraction(gpu_usage)
        except Exception as e:
            print(f"Warning: Could not set GPU memory fraction: {e}. Proceeding without limit.", flush=True)

    print(
        f"Initializing {num_cpu} parallel environments ({deck_type}) with opponent {opponent_type} and {int(usage * 100)}% per-core throttle...",
        flush=True,
    )

    # Create vectorized environment
    try:
        # Optimization: workers always use "random" internally because BatchedSubprocVecEnv
        # intercepts and runs the real opponent. This keeps workers from importing torch/sb3
        # and saves gigabytes of RAM.
        env_fns = [
            partial(create_env, rank=i, usage=usage, deck_type=deck_type, opponent_type="random")
            for i in range(num_cpu)
        ]
        # Use our batched-inference environment
        opponent_path = os.path.join(os.getcwd(), "checkpoints", "self_play_opponent.zip")
        env = BatchedSubprocVecEnv(
            env_fns, opponent_model_path=opponent_path if opponent_type == "self_play" else None
        )
        print("Batched workers initialized! Starting training loop...", flush=True)
    except Exception as e:
        print(f"CRITICAL ERROR during worker initialization: {e}", flush=True)
        import traceback

        traceback.print_exc()
        return

    # 2. Model configuration
    load_path = os.getenv("LOAD_MODEL")
    model = None
    if load_path and os.path.exists(load_path):
        try:
            print(f" [LOAD] Loading existing model from {load_path}...", flush=True)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = MaskablePPO.load(load_path, env=env, device=device)
            print(" [LOAD] Model loaded successfully.", flush=True)
        except ValueError as val_err:
            if "Observation spaces do not match" in str(val_err):
                print(
                    f" [WARNING] Checkpoint {load_path} has an incompatible observation space (likely from an older engine version).",
                    flush=True,
                )
                print(" [WARNING] Skipping load and starting fresh to maintain stability.", flush=True)
                model = None  # Force fresh start
            else:
                raise
        except Exception as load_err:
            print(f" [CRITICAL ERROR] Failed to load checkpoint: {load_err}", flush=True)
            import traceback

            traceback.print_exc()
            env.close()
            sys.exit(1)

    if model is None:
        print(" [INFO] Initializing fresh MaskablePPO model...", flush=True)
        model = MaskablePPO(
            "MlpPolicy",
            env,
            verbose=0,
            gamma=0.99,
            learning_rate=3e-4,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            tensorboard_log="./logs/ppo_tensorboard/",
            device="cuda",
        )

    # Dry-run support
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true", help="Initialize and exit")
    args, unknown = parser.parse_known_args()
    if args.dry_run:
        print(" [Dry Run] Workers initialized successfully. Exiting.", flush=True)
        env.close()
        return

    print(f"Starting Training on {num_cpu} workers (Usage: {usage * 100}%)...", flush=True)

    # Checkpoint callback
    checkpoint_callback = CheckpointCallback(
        save_freq=max(1, 200000 // num_cpu), save_path="./checkpoints/", name_prefix="lovelive_ppo_checkpoint"
    )

    # 3. Learning loop
    stats_callback = TrainingStatsCallback()
    best_rate_callback = SaveOnBestWinRateCallback(check_freq=1024, save_path="./checkpoints/")
    self_play_callback = SelfPlayUpdateCallback(update_freq=20000, save_path="./checkpoints/")
    callback_list = CallbackList([checkpoint_callback, stats_callback, best_rate_callback, self_play_callback])

    try:
        print(f"Starting Long-Running Training on {num_cpu} workers (Usage: {usage * 100}%)...")
        model.learn(total_timesteps=2_000_000_000, progress_bar=False, callback=callback_list)
        # Save final model
        os.makedirs("checkpoints", exist_ok=True)
        model.save("checkpoints/lovelive_ppo_optimized")
        print("Training Complete. Model Saved.")
    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving current progress...")
        model.save("checkpoints/lovelive_ppo_interrupted")
    finally:
        env.close()


if __name__ == "__main__":
    train()
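
# ---------------------------------------------------------------------------
# Example invocation (illustrative only; the specific values below are
# assumptions, but every environment variable shown is read by train() above).
# Run from the project root so the ai/ package resolves:
#
#   TRAIN_CPUS=8 TRAIN_USAGE=0.5 TRAIN_GPU_USAGE=0.7 TRAIN_BATCH_SIZE=256 \
#   TRAIN_STEPS=2048 TRAIN_EPOCHS=10 TRAIN_DECK=random_verified \
#   TRAIN_OPPONENT=self_play LOAD_MODEL=checkpoints/lovelive_ppo_optimized.zip \
#   python train_optimized.py
#
#   python train_optimized.py --dry-run   # initialize workers, then exit
# ---------------------------------------------------------------------------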