# StockZero-v2 / training-script-v2.py
import chess
import chess.engine
import numpy as np
import tensorflow as tf
import time
import os
import datetime
# --- 1. Neural Network (Policy and Value Network) ---
class PolicyValueNetwork(tf.keras.Model):
    def __init__(self, num_moves):
        super(PolicyValueNetwork, self).__init__()
        # Input shape is inferred on the first call (subclassed model)
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')
        self.flatten = tf.keras.layers.Flatten()
        self.dense_policy = tf.keras.layers.Dense(num_moves, activation='softmax', name='policy_head')
        self.dense_value = tf.keras.layers.Dense(1, activation='tanh', name='value_head')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.flatten(x)
        policy = self.dense_policy(x)
        value = self.dense_value(x)
        return policy, value
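
# Illustrative usage (a minimal sketch, not part of the training flow): the
# subclassed model builds its weights on first call and returns a policy
# distribution over the move space alongside a scalar value in [-1, 1].
# NUM_POSSIBLE_MOVES is defined in section 3 below.
#
#   net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
#   policy, value = net(np.zeros((1, 8, 8, 12), dtype=np.float32))
#   policy.shape  # -> (1, 4672)
#   value.shape   # -> (1, 1)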
# --- 2. Board Representation and Preprocessing ---
def board_to_input(board):
    """Encode a board as 8x8x12 one-hot planes: white pieces in planes 0-5, black in 6-11."""
    piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]
    input_planes = np.zeros((8, 8, 12), dtype=np.float32)
    for piece_type_index, piece_type in enumerate(piece_types):
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None and piece.piece_type == piece_type:
                plane_index = piece_type_index if piece.color == chess.WHITE else piece_type_index + 6
                row, col = chess.square_rank(square), chess.square_file(square)
                input_planes[row, col, plane_index] = 1.0
    return input_planes
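
# Quick sanity check (illustrative): the starting position yields exactly one
# 1.0 entry per piece across the 12 planes.
#
#   planes = board_to_input(chess.Board())
#   planes.shape  # -> (8, 8, 12)
#   planes.sum()  # -> 32.0, one entry per piece on the board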
def get_legal_moves_mask(board):
    """Return a {0,1} mask over the move-index space for the board's legal moves."""
    mask = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    for move in board.legal_moves:
        index = move_to_index(move)
        # Defensive check: skip any index that falls outside the move space
        if 0 <= index < NUM_POSSIBLE_MOVES:
            mask[index] = 1.0
    return mask
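
# Illustrative check: in the starting position every one of the 20 legal moves
# maps to a distinct non-promotion index, so the mask has exactly 20 ones.
#
#   mask = get_legal_moves_mask(chess.Board())
#   mask.shape  # -> (4672,)
#   mask.sum()  # -> 20.0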
# --- 3. Move Encoding/Decoding (Deterministic Implementation) ---
NUM_POSSIBLE_MOVES = 4672  # AlphaZero-style move-space size (8*8*73); this encoding uses indices 0-4351, which fit within it
def move_to_index(move):
    """Deterministic move-to-index conversion.

    Non-promotion moves are encoded as from_square * 64 + to_square (0-4095).
    Promotion moves are packed after them, keyed by promotion piece and target
    square; note that the source square is not stored for promotions.
    """
    if move.promotion is None:
        return move.from_square * 64 + move.to_square  # Source and target squares
    elif move.promotion == chess.KNIGHT:
        return 4096 + move.to_square  # Knight promotions start after non-promotion moves
    elif move.promotion == chess.BISHOP:
        return 4096 + 64 + move.to_square  # Bishop promotions after Knights
    elif move.promotion == chess.ROOK:
        return 4096 + 64 * 2 + move.to_square  # Rook promotions after Bishops
    elif move.promotion == chess.QUEEN:
        return 4096 + 64 * 3 + move.to_square  # Queen promotions after Rooks
    else:
        raise ValueError(f"Unknown promotion piece type: {move.promotion}")
def index_to_move(index, board):
    """Deterministic index-to-move conversion (index -> chess.Move), or None if not legal."""
    if 0 <= index < 4096:  # Non-promotion moves
        from_square, to_square = divmod(index, 64)
        move = chess.Move(from_square, to_square)
        return move if move in board.legal_moves else None
    # Promotion moves: the encoding stores only the promotion piece and target
    # square, so recover the source square by matching against the legal moves.
    if 4096 <= index < 4096 + 64 * 4:
        piece_offset, to_square = divmod(index - 4096, 64)
        promotion = [chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN][piece_offset]
        for move in board.legal_moves:
            if move.to_square == to_square and move.promotion == promotion:
                return move
    return None  # Invalid index or no matching legal move
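
# Round-trip sanity check (illustrative): encoding a legal move and decoding it
# against the same position should return the identical chess.Move.
#
#   board = chess.Board()
#   move = chess.Move.from_uci("e2e4")
#   index = move_to_index(move)          # 12 * 64 + 28 = 796
#   index_to_move(index, board) == move  # -> True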
def get_game_result_value(board):
    """Game result from White's perspective: +1 White win, -1 Black win, 0 otherwise."""
    if board.is_checkmate():
        # The side to move is the side that has been mated
        return 1 if board.turn == chess.BLACK else -1
    # Stalemate, insufficient material, 75-move rule, fivefold repetition,
    # variant draws, and unfinished games all score 0
    return 0
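
# Illustrative example: after the fool's mate Black has mated White, and the
# side to move (White) is the mated side, so the result is -1.
#
#   board = chess.Board()
#   for uci in ("f2f3", "e7e5", "g2g4", "d8h4"):
#       board.push_uci(uci)
#   board.is_checkmate()          # -> True
#   get_game_result_value(board)  # -> -1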
# --- 4. Monte Carlo Tree Search (MCTS) ---
class MCTSNode:
    def __init__(self, board, parent=None, prior_prob=0):
        self.board = board.copy()
        self.parent = parent
        self.children = {}
        self.visits = 0
        self.value_sum = 0
        self.prior_prob = prior_prob
        self.policy_prob = 0
        self.value = 0

    def select_child(self, exploration_constant=1.4):
        # PUCT-style selection: exploitation term plus prior-weighted exploration.
        # child.value is from the child's side-to-move perspective, so negate it
        # to score the move from this node's perspective.
        best_child = None
        best_ucb = -float('inf')
        for move, child in self.children.items():
            ucb = -child.value + exploration_constant * child.prior_prob * np.sqrt(self.visits) / (1 + child.visits)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        return best_child
    def expand(self, policy_probs):
        for move in self.board.legal_moves:
            move_index = move_to_index(move)
            prior_prob = policy_probs[move_index]
            child_board = self.board.copy()
            child_board.push(move)  # Apply the move so the child holds the resulting position
            self.children[move] = MCTSNode(child_board, parent=self, prior_prob=prior_prob)
    def evaluate(self, policy_value_net):
        input_board = board_to_input(self.board)
        policy_output, value_output = policy_value_net(np.expand_dims(input_board, axis=0))
        policy_probs = policy_output.numpy()[0]
        value = value_output.numpy()[0][0]
        # Mask out illegal moves and renormalize the policy
        legal_moves_mask = get_legal_moves_mask(self.board)
        masked_policy_probs = policy_probs * legal_moves_mask
        if np.sum(masked_policy_probs) > 0:
            masked_policy_probs /= np.sum(masked_policy_probs)
        else:
            # Network put no mass on legal moves; fall back to a uniform distribution
            masked_policy_probs = legal_moves_mask / np.sum(legal_moves_mask)
        self.policy_prob = masked_policy_probs
        self.value = value
        return value, masked_policy_probs

    def backup(self, value):
        self.visits += 1
        self.value_sum += value
        self.value = self.value_sum / self.visits
        if self.parent:
            # Flip the sign at each ply: a good position for one side is bad for the other
            self.parent.backup(-value)
def run_mcts(root_node, policy_value_net, num_simulations):
    for _ in range(num_simulations):
        # Selection: descend to a leaf following the selection rule
        node = root_node
        search_path = [node]
        while node.children and not node.board.is_game_over():
            node = node.select_child()
            search_path.append(node)
        leaf_node = search_path[-1]
        # Expansion/evaluation: use the network at non-terminal leaves
        if not leaf_node.board.is_game_over():
            value, policy_probs = leaf_node.evaluate(policy_value_net)
            leaf_node.expand(policy_probs)
        else:
            # Terminal leaf: convert the White-perspective result to the
            # perspective of the side to move at the leaf
            value = get_game_result_value(leaf_node.board)
            if leaf_node.board.turn == chess.BLACK:
                value = -value
        # Backup: propagate the value up the search path
        leaf_node.backup(value)
    return choose_best_move_from_mcts(root_node)
def choose_best_move_from_mcts(root_node, temperature=0.0):
    if temperature == 0:
        # Greedy: pick the most-visited child
        best_move = max(root_node.children, key=lambda move: root_node.children[move].visits)
    else:
        # Sample proportionally to visits^(1/temperature) for exploration
        visits = [root_node.children[move].visits for move in root_node.children]
        move_probs = np.array(visits) ** (1 / temperature)
        move_probs = move_probs / np.sum(move_probs)
        moves = list(root_node.children.keys())
        best_move = np.random.choice(moves, p=move_probs)
    return best_move
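
# Illustrative usage (a minimal sketch with an untrained network and a tiny
# simulation budget, so the chosen move is essentially arbitrary):
#
#   net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
#   root = MCTSNode(chess.Board())
#   move = run_mcts(root, net, num_simulations=10)  # -> a legal chess.Move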
# --- 5. RL Engine Class ---
class RLEngine:
    def __init__(self, policy_value_net, num_simulations_per_move=100):
        self.policy_value_net = policy_value_net
        self.num_simulations_per_move = num_simulations_per_move

    def choose_move(self, board):
        root_node = MCTSNode(board)
        best_move = run_mcts(root_node, self.policy_value_net, self.num_simulations_per_move)
        return best_move
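
# Illustrative usage: wrap the network in an engine and ask it for a move.
#
#   engine = RLEngine(PolicyValueNetwork(NUM_POSSIBLE_MOVES), num_simulations_per_move=25)
#   move = engine.choose_move(chess.Board())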
# --- 6. Training Functions ---
def self_play_game(engine, model, num_simulations):
    game_history = []
    board = chess.Board()
    while not board.is_game_over():
        root_node = MCTSNode(board)
        run_mcts(root_node, model, num_simulations)
        policy_targets = create_policy_targets_from_mcts_visits(root_node)
        game_history.append((board.fen(), policy_targets))
        best_move = choose_best_move_from_mcts(root_node, temperature=0.8)  # Exploration temperature
        board.push(best_move)
    # Label every position with the final result from the perspective of the side to move
    game_result = get_game_result_value(board)  # From White's perspective
    for i in range(len(game_history)):
        fen, policy_target = game_history[i]
        position_turn = chess.Board(fen).turn  # Turn at this position, not at game end
        game_history[i] = (fen, policy_target, game_result if position_turn == chess.WHITE else -game_result)
    return game_history
def create_policy_targets_from_mcts_visits(root_node):
    # Visit counts from the root, normalized into a probability distribution
    policy_targets = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    for move, child_node in root_node.children.items():
        move_index = move_to_index(move)
        policy_targets[move_index] = child_node.visits
    policy_targets /= np.sum(policy_targets)
    return policy_targets
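
# Illustrative check, reusing the `root` from the run_mcts sketch above: after
# a search has populated the root's children, the targets sum to 1.
#
#   targets = create_policy_targets_from_mcts_visits(root)
#   targets.shape  # -> (4672,)
#   targets.sum()  # -> 1.0 (visit counts normalized to fractions)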
def train_step(model, board_inputs, policy_targets, value_targets, optimizer):
    with tf.GradientTape() as tape:
        policy_outputs, value_outputs = model(board_inputs)
        policy_loss = tf.keras.losses.CategoricalCrossentropy()(policy_targets, policy_outputs)
        value_loss = tf.keras.losses.MeanSquaredError()(value_targets, value_outputs)
        total_loss = policy_loss + value_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return total_loss, policy_loss, value_loss
def train_network(model, game_histories, optimizer, epochs=10, batch_size=32):
    # Flatten all games into (board input, policy target, value target) examples
    all_board_inputs = []
    all_policy_targets = []
    all_value_targets = []
    for game_history in game_histories:
        for fen, policy_target, game_result in game_history:
            board = chess.Board(fen)
            all_board_inputs.append(board_to_input(board))
            all_policy_targets.append(policy_target)
            all_value_targets.append(np.array([game_result]))
    all_board_inputs = np.array(all_board_inputs)
    all_policy_targets = np.array(all_policy_targets)
    all_value_targets = np.array(all_value_targets)
    dataset = tf.data.Dataset.from_tensor_slices((all_board_inputs, all_policy_targets, all_value_targets))
    dataset = dataset.shuffle(buffer_size=len(all_board_inputs)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for batch_inputs, batch_policy_targets, batch_value_targets in dataset:
            loss, p_loss, v_loss = train_step(model, batch_inputs, batch_policy_targets, batch_value_targets, optimizer)
        # Report the losses from the last batch of the epoch
        print(f"  Loss: {loss:.4f}, Policy Loss: {p_loss:.4f}, Value Loss: {v_loss:.4f}")
# --- 7. Main Training Execution in Colab ---
if __name__ == "__main__":
    # --- Check GPU Availability in Colab ---
    if tf.config.list_physical_devices('GPU'):
        print("\n\nGPU is available and will be used for training.\n\n")
        gpu_device = '/GPU:0'  # Use GPU 0 if available
    else:
        print("\n\nGPU is not available. Training will use CPU (may be slow).\n\n")
        gpu_device = '/CPU:0'

    with tf.device(gpu_device):  # Explicitly place operations on the chosen device
        # Initialize Neural Network, Engine, and Optimizer
        policy_value_net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
        engine = RLEngine(policy_value_net, num_simulations_per_move=100)
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

        # --- Training Parameters ---
        num_self_play_games = 50  # Adjust for longer training
        epochs = 5  # Adjust for longer training

        # --- Run Self-Play and Training ---
        game_histories = []
        start_time = time.time()

        # --- Model Save Directory in Colab ---
        MODEL_SAVE_DIR = "models_colab"  # Directory to save the model in Colab
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)  # Create the directory if it doesn't exist

        for i in range(num_self_play_games):
            print(f"Self-play game {i+1}/{num_self_play_games} \n")
            game_history = self_play_game(engine, policy_value_net, num_simulations=50)  # Reduced simulations for faster games
            game_histories.append(game_history)

        train_network(policy_value_net, game_histories, optimizer, epochs=epochs)

        end_time = time.time()
        training_time = end_time - start_time
        print(f"\n\n ---- Training completed in {training_time:.2f} seconds. ---- \n")

        # --- Save the trained model weights, versioned by timestamp ---
        current_datetime = datetime.datetime.now()
        model_version_str = current_datetime.strftime("%Y-%m-%d-%H%M")  # Hour and minute included for uniqueness
        model_save_path = os.path.join(MODEL_SAVE_DIR, f"StockZero-{model_version_str}.weights.h5")  # Versioned filename
        policy_value_net.save_weights(model_save_path)  # Saves model weights only (not the architecture)
        print(f"Trained model weights saved to '{model_save_path}' in the '{MODEL_SAVE_DIR}' directory in Colab.")

        # --- Zip and download the saved model (Colab only; remove this block outside Colab) ---
        import shutil
        zip_file_path = f"StockZero-{model_version_str}"
        shutil.make_archive(zip_file_path, 'zip', MODEL_SAVE_DIR)  # Create zip archive
        print(f"Model directory zipped to '{zip_file_path}.zip'. Download this file.")
        from google.colab import files
        files.download(f"{zip_file_path}.zip")  # Trigger download in Colab
        print("\n\n ----- Training finished. ------- \n\n")