| """AE Manager - loads trained MaskablePPO and returns actions for Bomberman.""" |
|
|
| import os |
| import sys |
| import numpy as np |
| from sb3_contrib import MaskablePPO |
|
|
| |
# Directories that may contain the til-26-ae repository (the one providing
# ``til_environment/bomberman_env.py``), tried in priority order.
_ENV_REPO_CANDIDATES = (
    os.path.join(os.path.dirname(__file__), "..", "til-26-ae"),
    "/app/til-26-ae-repo/til-26-ae",
    "til-26-ae",
)


def ensure_env_repo_on_path(candidates=_ENV_REPO_CANDIDATES):
    """Prepend the first candidate directory holding the Bomberman env to ``sys.path``.

    A candidate qualifies when ``<candidate>/til_environment/bomberman_env.py``
    exists. The inserted entry is made absolute so that relative candidates
    keep resolving correctly even if the working directory changes after import.

    Args:
        candidates: Directory paths to try, in priority order.

    Returns:
        The absolute path that was inserted, or None if no candidate matched.
    """
    for candidate in candidates:
        marker = os.path.join(candidate, "til_environment", "bomberman_env.py")
        if os.path.isfile(marker):
            path = os.path.abspath(candidate)
            sys.path.insert(0, path)
            return path
    return None


ensure_env_repo_on_path()
|
|
|
|
class AEManager:
    """Loads a trained MaskablePPO model and serves inference for Bomberman.

    If no checkpoint can be loaded, :meth:`ae` falls back to picking a
    uniformly random action among those allowed by the observation's
    ``action_mask``.
    """

    # Observation keys in the exact order _flatten_obs concatenates them.
    # This order must match the flattening used during training.
    _OBS_KEYS = (
        "agent_viewcone",
        "base_viewcone",
        "direction",
        "location",
        "base_location",
        "health",
        "frozen_ticks",
        "base_health",
        "team_resources",
        "team_bombs",
        "step",
    )

    # Number of actions assumed for a missing action_mask when the loaded
    # model (if any) does not expose a discrete ``action_space.n``.
    _DEFAULT_N_ACTIONS = 6

    # Action returned when the mask allows no action at all.
    # NOTE(review): presumably a no-op / stay action -- confirm against the env.
    _FALLBACK_ACTION = 4

    def __init__(self):
        """Try each candidate checkpoint in order and keep the first that loads.

        Candidates, in priority order: the ``MODEL_PATH`` environment variable,
        zip files in the parent of this module's directory, then ``/app/data``.
        Load failures are reported and the next candidate is tried.
        """
        self.model = None
        # Flattened observation length the loaded model expects (None if unknown).
        self._obs_size = None

        here = os.path.dirname(__file__)
        candidates = [
            os.environ.get("MODEL_PATH", ""),
            # NOTE(review): locally phase1 is preferred over phase3, whereas the
            # /app/data entries prefer phase3 -> phase2 -> phase1. Confirm intended.
            os.path.join(here, "..", "phase1_final.zip"),
            os.path.join(here, "..", "phase3_final.zip"),
            "/app/data/phase3_final.zip",
            "/app/data/phase2_final.zip",
            "/app/data/phase1_final.zip",
        ]
        for path in candidates:
            if not (path and os.path.isfile(path)):
                continue
            try:
                self.model = MaskablePPO.load(path)
            except Exception as exc:  # best effort: fall through to the next checkpoint
                print(f"[AE Manager] Failed to load {path}: {exc}")
                continue
            print(f"[AE Manager] Loaded model from {path}")
            break

        if self.model is None:
            print("[AE Manager] No trained model found -- will return random valid actions.")
            return

        # Record the expected input size so ae() can report mismatches clearly.
        shape = getattr(self.model.observation_space, "shape", None)
        if shape:
            self._obs_size = int(np.prod(shape))

    @staticmethod
    def _flatten_obs(obs_dict):
        """Flatten an observation dict into the float32 vector used during training.

        Each field listed in ``_OBS_KEYS`` is converted to float32 and raveled,
        so scalars and arrays of any shape are handled the same way. The
        results are then concatenated in ``_OBS_KEYS`` order.

        Args:
            obs_dict: Mapping that contains every key in ``_OBS_KEYS``.

        Returns:
            A 1-D ``np.float32`` array.

        Raises:
            KeyError: if a required key is missing.
        """
        return np.concatenate(
            [np.asarray(obs_dict[key], dtype=np.float32).ravel() for key in AEManager._OBS_KEYS]
        )

    def _action_mask(self, observation):
        """Return the observation's action mask as a flat bool array.

        If ``action_mask`` is missing (or None), every action is allowed. The
        action count then comes from the loaded model's ``action_space.n`` when
        available, otherwise ``_DEFAULT_N_ACTIONS``.
        """
        n_actions = self._DEFAULT_N_ACTIONS
        if self.model is not None:
            n_actions = int(getattr(self.model.action_space, "n", n_actions))
        raw = observation.get("action_mask")
        if raw is None:
            raw = [1] * n_actions
        return np.asarray(raw, dtype=bool).ravel()

    def ae(self, observation: dict) -> int:
        """Return the action to take for a single observation.

        Args:
            observation: Observation dict. It needs the keys in ``_OBS_KEYS``
                when a model is loaded, and may carry an ``action_mask`` of
                per-action flags (truthy = allowed).

        Returns:
            The chosen action index, or ``_FALLBACK_ACTION`` when the mask
            allows no action.

        Raises:
            ValueError: if the flattened observation length differs from the
                loaded model's observation size.
        """
        action_mask = self._action_mask(observation)
        valid = np.flatnonzero(action_mask)
        if valid.size == 0:
            # Nothing is allowed: return the same fallback whether or not a
            # model is loaded. Otherwise the model path would pick a masked action.
            return self._FALLBACK_ACTION

        if self.model is None:
            return int(np.random.choice(valid))

        obs_vec = self._flatten_obs(observation)
        if self._obs_size is not None and obs_vec.size != self._obs_size:
            raise ValueError(
                f"[AE Manager] Observation has {obs_vec.size} features "
                f"but the model expects {self._obs_size}"
            )

        action, _ = self.model.predict(
            obs_vec,
            action_masks=action_mask,
            deterministic=True,
        )
        # predict may return a 0-d or 1-element array; take its single value.
        return int(np.asarray(action).ravel()[0])
|
|