trioskosmos committed
Commit 25d7b1e · verified · 1 Parent(s): 9ae5bf0

Upload ai/agents/fast_mcts.py with huggingface_hub

Files changed (1)
  1. ai/agents/fast_mcts.py +164 -0
ai/agents/fast_mcts.py ADDED
@@ -0,0 +1,164 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Dict, List, Tuple
+
+ import numpy as np
+
+ # Assumes the GameState interface from the existing engine code;
+ # we import the real class rather than duck-typing to be safe.
+ from engine.game.game_state import GameState
+
+
+ @dataclass
+ class HeuristicMCTSConfig:
+     num_simulations: int = 100
+     c_puct: float = 1.4
+     depth_limit: int = 50
+
+
+ class HeuristicNode:
+     def __init__(self, parent=None, prior=1.0):
+         self.parent = parent
+         self.children: Dict[int, "HeuristicNode"] = {}
+         self.visit_count = 0
+         self.value_sum = 0.0
+         self.prior = prior
+         self.untried_actions: List[int] = []
+         self.player_just_moved = -1
+
+     @property
+     def value(self):
+         if self.visit_count == 0:
+             return 0.0
+         return self.value_sum / self.visit_count
+
+     def ucb_score(self, c_puct):
+         # Standard UCB1: Q + c * sqrt(ln(N_parent) / N_child).
+         # AlphaZero uses a PUCT variant weighted by policy priors; since there is
+         # no policy network here ("MCTS without training"), plain UCB1 with
+         # uniform priors is used instead.
+         if self.visit_count == 0:
+             return float("inf")
+
+         parent_visits = self.parent.visit_count if self.parent else 1
+         exploitation = self.value
+         exploration = c_puct * math.sqrt(math.log(parent_visits) / self.visit_count)
+         return exploitation + exploration
+
+
+ class HeuristicMCTS:
+     """
+     MCTS that uses random rollouts and heuristics instead of a neural network.
+     It works 'without training' because it relies only on the game rules
+     (simulation) and hard-coded domain knowledge (rollout policy and
+     cutoff evaluation).
+     """
+
+     def __init__(self, config: HeuristicMCTSConfig):
+         self.config = config
+         self.root = None
+
+     def search(self, state: GameState) -> int:
+         self.root = HeuristicNode(prior=1.0)
+         # The root state itself is never mutated (each simulation works on a copy),
+         # but its legal actions are needed to seed expansion.
+         legal = state.get_legal_actions()
+         self.root.untried_actions = [i for i, x in enumerate(legal) if x]
+         self.root.player_just_moved = 1 - state.current_player  # the opponent moved last
+         root_player = state.current_player
+
+         for _ in range(self.config.num_simulations):
+             node = self.root
+             sim_state = state.copy()
+
+             # 1. Selection
+             path = [node]
+             while node.children and not node.untried_actions:
+                 action, node = self._select_best_step(node)
+                 sim_state = sim_state.step(action)
+                 path.append(node)
+
+             # 2. Expansion
+             if node.untried_actions:
+                 action = node.untried_actions.pop()
+                 sim_state = sim_state.step(action)
+                 child = HeuristicNode(parent=node, prior=1.0)
+                 child.player_just_moved = 1 - sim_state.current_player  # the player who took 'action'
+                 # Seed the child's untried actions so the tree can keep growing below this ply.
+                 if not sim_state.is_terminal():
+                     child_legal = sim_state.get_legal_actions()
+                     child.untried_actions = [i for i, x in enumerate(child_legal) if x]
+                 node.children[action] = child
+                 node = child
+                 path.append(node)
+
+             # 3. Simulation (rollout): play until terminal or the depth limit.
+             depth = 0
+             while not sim_state.is_terminal() and depth < self.config.depth_limit:
+                 legal = sim_state.get_legal_actions()
+                 legal_indices = [i for i, x in enumerate(legal) if x]
+                 if not legal_indices:
+                     break
+                 # Random rollout policy (no training required).
+                 action = np.random.choice(legal_indices)
+                 sim_state = sim_state.step(action)
+                 depth += 1
+
+             # 4. Backpropagation
+             # Evaluate the outcome from the root player's perspective: the true
+             # game reward at terminal states, the heuristic at depth cutoffs.
+             if sim_state.is_terminal():
+                 reward = sim_state.get_reward(root_player)  # 1.0 if the root player wins
+             else:
+                 reward = self._heuristic_eval(sim_state, root_player)
+
+             for n in reversed(path):
+                 n.visit_count += 1
+                 # Store each node's value from the perspective of the player who
+                 # just moved into it, so UCB selection at the parent maximises the
+                 # value of the player choosing there (assumes a two-player,
+                 # zero-sum reward).
+                 if n.player_just_moved == root_player:
+                     n.value_sum += reward
+                 else:
+                     n.value_sum -= reward
+
+         # Select the most visited move (robust child).
+         if not self.root.children:
+             return 0  # fallback when nothing was expanded
+
+         best_action = max(self.root.children.items(), key=lambda item: item[1].visit_count)[0]
+         return best_action
+
+     def _select_best_step(self, node: HeuristicNode) -> Tuple[int, HeuristicNode]:
+         # Pick the child with the highest UCB1 score.
+         best_score = -float("inf")
+         best_item = None
+
+         for action, child in node.children.items():
+             score = child.ucb_score(self.config.c_puct)
+             if score > best_score:
+                 best_score = score
+                 best_item = (action, child)
+
+         return best_item
+
+     def _heuristic_eval(self, state: GameState, root_player: int) -> float:
+         """
+         Evaluate a state without a neural network.
+         Logic: more blades/hearts/lives = better.
+         """
+         p = state.players[root_player]
+         opp = state.players[1 - root_player]
+
+         # Score = 0.5 * (my lives - opp lives) + 0.05 * (my power - opp power)
+         score = 0.0
+         score += (len(p.success_lives) - len(opp.success_lives)) * 0.5
+
+         my_power = p.get_total_blades(state.member_db)
+         opp_power = opp.get_total_blades(state.member_db)
+         score += (my_power - opp_power) * 0.05
+
+         # Clamp to [-1, 1]
+         return max(-1.0, min(1.0, score))
+
+
+ if __name__ == "__main__":
+     pass
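
For context, a minimal sketch of how this agent might be driven from a game loop. It is not part of the commit: the import path simply mirrors the uploaded file location, the `GameState()` default constructor is an assumption, and only the methods already used above (`copy`, `step`, `get_legal_actions`, `is_terminal`, `get_reward`, `current_player`) are relied on.

```python
# Hypothetical driver loop (not part of this commit).
from ai.agents.fast_mcts import HeuristicMCTS, HeuristicMCTSConfig
from engine.game.game_state import GameState

config = HeuristicMCTSConfig(num_simulations=200, c_puct=1.4, depth_limit=50)
agent = HeuristicMCTS(config)

state = GameState()  # assumed default constructor; build the initial state however the engine requires
while not state.is_terminal():
    action = agent.search(state)  # index of the most visited root child
    state = state.step(action)

print(state.get_reward(0))  # final reward from player 0's perspective
```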