import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
import ale_py
import gymnasium as gym
import matplotlib.pyplot as plt

# Register ALE environments
gym.register_envs(ale_py)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# ==================== Policy Networks ====================

class CartPolePolicy(nn.Module):
    """Policy network for the CartPole environment"""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(CartPolePolicy, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize network weights"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=-1)


class PongPolicy(nn.Module):
    """Policy network for Pong with a CNN architecture"""

    def __init__(self, action_dim=2):
        super(PongPolicy, self).__init__()
        # CNN layers for processing 80x80 images
        self.conv1 = nn.Conv2d(1, 16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        # Spatial size after the convolutions: 80 -> 19 -> 8
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, action_dim)

        # Initialize weights for better training stability
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize network weights with proper initialization"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        # x shape: (80, 80) or (batch, 80, 80) -> add batch/channel dimensions as needed
        if len(x.shape) == 2:
            x = x.unsqueeze(0).unsqueeze(0)
        elif len(x.shape) == 3:
            x = x.unsqueeze(1)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=-1)


# ==================== Helper Functions ====================

def preprocess(image):
    """Preprocess a 210x160x3 uint8 frame into an 80x80 2D float array"""
    image = image[35:195]               # crop the playing field
    image = image[::2, ::2, 0]          # downsample by factor of 2, keep one channel
    image = image.astype(np.float32)    # copy so the env's frame buffer is not modified in place
    image[image == 144] = 0             # erase background (background type 1)
    image[image == 109] = 0             # erase background (background type 2)
    image[image != 0] = 1               # everything else (paddles, ball) set to 1
    return image


def compute_returns(rewards, gamma):
    """Compute discounted returns for each timestep"""
    returns = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns, dtype=torch.float32).to(device)
    # Normalize returns for more stable training
    if len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns


def moving_average(data, window_size):
    """Compute the moving average over a trailing window"""
    moving_avg = []
    for i in range(len(data)):
        if i < window_size:
            moving_avg.append(np.mean(data[:i + 1]))
        else:
            moving_avg.append(np.mean(data[i - window_size + 1:i + 1]))
    return np.array(moving_avg)
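
# Optional sanity check (not called during training): a minimal sketch that verifies
# the tensor shapes the two policies expect and the behavior of compute_returns on a
# toy reward sequence. The helper name `sanity_check_shapes` is illustrative and not
# part of the original script.
def sanity_check_shapes():
    """Quick shape and return checks for the policy networks and helpers."""
    # CartPole policy: a 4-dim state should map to a valid 2-way action distribution
    cartpole_policy = CartPolePolicy(state_dim=4, action_dim=2).to(device)
    probs = cartpole_policy(torch.zeros(4, device=device))
    assert probs.shape == (2,)
    assert abs(probs.sum().item() - 1.0) < 1e-5

    # Pong policy: a single 80x80 frame difference should map to a (1, 2) distribution
    pong_policy = PongPolicy(action_dim=2).to(device)
    probs = pong_policy(torch.zeros(80, 80, device=device))
    assert probs.shape == (1, 2)

    # Discounted returns: with gamma=0.9 the raw returns of [0, 0, 1] are
    # [0.81, 0.9, 1.0]; compute_returns then normalizes them to zero mean, unit std
    returns = compute_returns([0.0, 0.0, 1.0], gamma=0.9)
    assert returns.shape == (3,)
    print("Sanity checks passed.")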

# ==================== Policy Gradient Algorithm ====================

def train_policy_gradient(env_name, policy, optimizer, gamma, num_episodes,
                          max_steps=None, is_pong=False, action_map=None):
    """
    Train policy using the REINFORCE algorithm

    Args:
        env_name: Name of the gym environment
        policy: Policy network
        optimizer: PyTorch optimizer
        gamma: Discount factor
        num_episodes: Number of training episodes
        max_steps: Maximum steps per episode (None for default)
        is_pong: Whether this is the Pong environment
        action_map: Mapping from policy action to env action (for Pong)
    """
    env = gym.make(env_name)
    episode_rewards = []

    for episode in range(num_episodes):
        state, _ = env.reset()

        # Preprocess state for Pong
        if is_pong:
            state = preprocess(state)
            prev_frame = None  # Track previous frame for motion

        log_probs = []
        rewards = []
        done = False
        step = 0

        while not done:
            # For Pong, use frame difference (motion signal)
            if is_pong:
                cur_frame = state
                if prev_frame is not None:
                    state_input = cur_frame - prev_frame
                else:
                    state_input = np.zeros_like(cur_frame, dtype=np.float32)
                prev_frame = cur_frame
                state_tensor = torch.FloatTensor(state_input).to(device)
            else:
                # Convert state to tensor
                state_tensor = torch.FloatTensor(state).to(device)

            # Get action probabilities
            action_probs = policy(state_tensor)

            # Sample action from the distribution
            dist = Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            # Map action for Pong (0,1 -> 2,3)
            if is_pong:
                env_action = action_map[action.item()]
            else:
                env_action = action.item()

            # Take action in environment
            next_state, reward, terminated, truncated, _ = env.step(env_action)
            done = terminated or truncated

            # Preprocess next state for Pong
            if is_pong:
                next_state = preprocess(next_state)

            # Store log probability and reward
            log_probs.append(log_prob)
            rewards.append(reward)

            state = next_state
            step += 1

            if max_steps and step >= max_steps:
                break

        # Compute returns
        returns = compute_returns(rewards, gamma)

        # Compute policy gradient loss
        policy_loss = []
        for log_prob, R in zip(log_probs, returns):
            policy_loss.append(-log_prob * R)

        # Optimize policy
        optimizer.zero_grad()
        policy_loss = torch.stack(policy_loss).sum()
        policy_loss.backward()
        # Gradient clipping for training stability
        torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
        optimizer.step()

        # Record episode reward
        episode_reward = sum(rewards)
        episode_rewards.append(episode_reward)

        # Print progress
        if (episode + 1) % 100 == 0:
            avg_reward = np.mean(episode_rewards[-100:])
            print(f"Episode {episode + 1}/{num_episodes}, "
                  f"Avg Reward (last 100): {avg_reward:.2f}")

        # Save a checkpoint for Pong every 500 episodes
        if is_pong and (episode + 1) % 500 == 0:
            checkpoint_path = f'pong_checkpoint_ep{episode + 1}.pth'
            torch.save({
                'episode': episode + 1,
                'policy_state_dict': policy.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'episode_rewards': episode_rewards,
            }, checkpoint_path)
            print(f" → Checkpoint saved: {checkpoint_path}")

    env.close()
    return episode_rewards
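
# Optional resume helper (a sketch, not called by the main script): the training loop
# above saves Pong checkpoints every 500 episodes but never reloads them. The helper
# name `load_pong_checkpoint` is illustrative; it assumes a checkpoint produced by the
# torch.save call in train_policy_gradient.
def load_pong_checkpoint(checkpoint_path, policy, optimizer):
    """Restore policy/optimizer state and the reward history from a saved checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    policy.load_state_dict(checkpoint['policy_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_episode = checkpoint['episode']
    episode_rewards = checkpoint['episode_rewards']
    print(f"Resumed from {checkpoint_path} at episode {start_episode}")
    return start_episode, episode_rewards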

def evaluate_policy(env_name, policy, num_episodes=500, is_pong=False, action_map=None):
    """Evaluate the trained policy over multiple episodes"""
    env = gym.make(env_name)
    eval_rewards = []

    for episode in range(num_episodes):
        state, _ = env.reset()

        if is_pong:
            state = preprocess(state)
            prev_frame = None  # Track previous frame for motion

        episode_reward = 0
        done = False

        while not done:
            # For Pong, use frame difference (motion signal)
            if is_pong:
                cur_frame = state
                if prev_frame is not None:
                    state_input = cur_frame - prev_frame
                else:
                    state_input = np.zeros_like(cur_frame, dtype=np.float32)
                prev_frame = cur_frame
                state_tensor = torch.FloatTensor(state_input).to(device)
            else:
                state_tensor = torch.FloatTensor(state).to(device)

            # Act greedily during evaluation
            with torch.no_grad():
                action_probs = policy(state_tensor)
                action = torch.argmax(action_probs).item()

            if is_pong:
                env_action = action_map[action]
            else:
                env_action = action

            next_state, reward, terminated, truncated, _ = env.step(env_action)
            done = terminated or truncated

            if is_pong:
                next_state = preprocess(next_state)

            episode_reward += reward
            state = next_state

        eval_rewards.append(episode_reward)

        if (episode + 1) % 100 == 0:
            print(f"Evaluated {episode + 1}/{num_episodes} episodes")

    env.close()
    return eval_rewards


def plot_results(episode_rewards, eval_rewards, title, save_prefix):
    """Plot the training curve and the evaluation histogram"""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Plot training curve
    ax1 = axes[0]
    episodes = np.arange(1, len(episode_rewards) + 1)
    ma = moving_average(episode_rewards, 100)
    ax1.plot(episodes, episode_rewards, alpha=0.3, label='Episode Reward')
    ax1.plot(episodes, ma, linewidth=2, label='Moving Average (100 episodes)')
    ax1.set_xlabel('Episode')
    ax1.set_ylabel('Reward')
    ax1.set_title(f'{title} - Training Curve')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot evaluation histogram
    ax2 = axes[1]
    mean_reward = np.mean(eval_rewards)
    std_reward = np.std(eval_rewards)
    ax2.hist(eval_rewards, bins=30, edgecolor='black', alpha=0.7)
    ax2.axvline(mean_reward, color='red', linestyle='--', linewidth=2,
                label=f'Mean: {mean_reward:.2f}')
    ax2.set_xlabel('Episode Reward')
    ax2.set_ylabel('Frequency')
    ax2.set_title(f'{title} - Evaluation Histogram ({len(eval_rewards)} episodes)\n'
                  f'Mean: {mean_reward:.2f}, Std: {std_reward:.2f}')
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f'{save_prefix}_results.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n{title} Evaluation Results:")
    print(f"Mean Reward: {mean_reward:.2f}")
    print(f"Std Reward: {std_reward:.2f}")


# ==================== Main Training Scripts ====================

def train_cartpole():
    """Train CartPole-v1"""
    print("\n" + "=" * 60)
    print("Training CartPole-v1")
    print("=" * 60 + "\n")

    # Environment parameters
    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    env.close()

    # Hyperparameters
    gamma = 0.95
    learning_rate = 0.01
    num_episodes = 1000

    # Initialize policy and optimizer
    policy = CartPolePolicy(state_dim, action_dim).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=learning_rate)

    # Train
    episode_rewards = train_policy_gradient(
        'CartPole-v1', policy, optimizer, gamma, num_episodes
    )

    # Evaluate
    print("\nEvaluating trained policy...")
    eval_rewards = evaluate_policy('CartPole-v1', policy, num_episodes=500)

    # Plot results
    plot_results(episode_rewards, eval_rewards, 'CartPole-v1', 'cartpole')

    # Save model
    torch.save(policy.state_dict(), 'cartpole_policy.pth')
    print("\nModel saved as 'cartpole_policy.pth'")

    return policy, episode_rewards, eval_rewards
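
# Optional visual check (a sketch, not called by the main script): watch a trained
# CartPole policy using gymnasium's human render mode. The helper name `watch_cartpole`
# is illustrative; it assumes a policy returned by train_cartpole() or loaded from
# 'cartpole_policy.pth'.
def watch_cartpole(policy, num_episodes=3):
    """Roll out the greedy policy with on-screen rendering."""
    env = gym.make('CartPole-v1', render_mode="human")
    for _ in range(num_episodes):
        state, _ = env.reset()
        done = False
        episode_reward = 0
        while not done:
            state_tensor = torch.FloatTensor(state).to(device)
            with torch.no_grad():
                action = torch.argmax(policy(state_tensor)).item()
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward
        print(f"Episode reward: {episode_reward}")
    env.close()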
print(f"Using learning rate: {learning_rate} (reduced for stability)") print(f"Action mapping: 0->RIGHT(2), 1->LEFT(3)") print(f"Gradient clipping: max_norm=1.0") print(f"Weight initialization: Kaiming (Conv) + Xavier (FC)\n") # Train with periodic checkpointing print("Starting training (checkpoints saved every 500 episodes)...\n") episode_rewards = train_policy_gradient( 'ALE/Pong-v5', policy, optimizer, gamma, num_episodes, is_pong=True, action_map=action_map ) print("\nTraining completed!") # Evaluate print("\nEvaluating trained policy...") eval_rewards = evaluate_policy( 'ALE/Pong-v5', policy, num_episodes=500, is_pong=True, action_map=action_map ) # Plot results plot_results(episode_rewards, eval_rewards, 'Pong-v5', 'pong') # Save model torch.save(policy.state_dict(), 'pong_policy.pth') print("\nModel saved as 'pong_policy.pth'") return policy, episode_rewards, eval_rewards # ==================== Run Training ==================== if __name__ == "__main__": # Train CartPole #cartpole_policy, cartpole_train_rewards, cartpole_eval_rewards = train_cartpole() # Train Pong (this will take longer) #print("\n\nNote: Pong training will take significantly longer (may take hours)") #print("You may want to reduce num_episodes if just testing the code.\n") # Uncomment the line below to train Pong pong_policy, pong_train_rewards, pong_eval_rewards = train_pong()