import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
import ale_py
import gymnasium as gym
import matplotlib.pyplot as plt
from collections import deque
# Register ALE environments
gym.register_envs(ale_py)
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# ==================== Policy Networks ====================
class CartPolePolicy(nn.Module):
"""Policy network for CartPole environment"""
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(CartPolePolicy, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, action_dim)
# Initialize weights
self._initialize_weights()
def _initialize_weights(self):
"""Initialize network weights"""
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0.0)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.softmax(x, dim=-1)
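# A minimal sanity-check sketch (not called during training): shows how the
# softmax output above is turned into a sampled action via Categorical, using
# a dummy 4-dim state that is only an assumption matching CartPole-v1.
def _cartpole_policy_smoke_test():
    policy = CartPolePolicy(state_dim=4, action_dim=2)
    dummy_state = torch.zeros(4)          # placeholder state, not a real observation
    probs = policy(dummy_state)           # shape (2,), entries sum to 1
    dist = Categorical(probs)
    action = dist.sample()                # tensor(0) or tensor(1)
    return probs, action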
class PongPolicy(nn.Module):
"""Policy network for Pong with CNN architecture"""
def __init__(self, action_dim=2):
super(PongPolicy, self).__init__()
# CNN layers for processing 80x80 images
self.conv1 = nn.Conv2d(1, 16, kernel_size=8, stride=4)
self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
# Calculate size after convolutions: 80 -> 19 -> 8
self.fc1 = nn.Linear(32 * 8 * 8, 256)
self.fc2 = nn.Linear(256, action_dim)
# Initialize weights for better training stability
self._initialize_weights()
def _initialize_weights(self):
"""Initialize network weights with proper initialization"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0.0)
def forward(self, x):
# x shape: (batch, 80, 80) -> add channel dimension
if len(x.shape) == 2:
x = x.unsqueeze(0).unsqueeze(0)
elif len(x.shape) == 3:
x = x.unsqueeze(1)
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.softmax(x, dim=-1)
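# A shape-check sketch for the CNN above (illustrative only): confirms the
# size arithmetic in the comment, 80 -> 19 -> 8, i.e. a flattened feature
# size of 32 * 8 * 8 = 2048. The random frame is a stand-in, not real Pong data.
def _pong_policy_shape_check():
    policy = PongPolicy(action_dim=2)
    dummy_frame = torch.rand(80, 80)      # single preprocessed frame
    probs = policy(dummy_frame)           # batched internally to (1, 1, 80, 80)
    assert probs.shape == (1, 2)
    return probs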
# ==================== Helper Functions ====================
def preprocess(image):
"""Prepro 210x160x3 uint8 frame into 6400 (80x80) 2D float array"""
image = image[35:195] # crop
image = image[::2, ::2, 0] # downsample by factor of 2
image[image == 144] = 0 # erase background (background type 1)
image[image == 109] = 0 # erase background (background type 2)
image[image != 0] = 1 # everything else (paddles, ball) just set to 1
    return image.astype(float)  # already an 80x80 array after crop and downsample
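# An illustrative sketch of how preprocess feeds the frame-difference "motion
# signal" used in the training loop below; the all-zero 210x160x3 frame is a
# stand-in for a real ALE observation.
def _frame_difference_example():
    raw_frame = np.zeros((210, 160, 3), dtype=np.uint8)   # placeholder observation
    cur_frame = preprocess(raw_frame)                      # 80x80 float array
    prev_frame = np.zeros_like(cur_frame)
    motion = cur_frame - prev_frame                        # what the policy actually sees
    return motion.shape                                    # (80, 80)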
def compute_returns(rewards, gamma):
"""Compute discounted returns for each timestep"""
returns = []
R = 0
for r in reversed(rewards):
R = r + gamma * R
returns.insert(0, R)
returns = torch.tensor(returns, dtype=torch.float32).to(device)
# Normalize returns for more stable training
if len(returns) > 1:
returns = (returns - returns.mean()) / (returns.std() + 1e-8)
return returns
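# A worked example for compute_returns (sketch only): with rewards [0, 0, 1]
# and gamma = 0.5, the recursion R_t = r_t + gamma * R_{t+1} gives unnormalized
# returns [0.25, 0.5, 1.0], which are then standardized to zero mean / unit std.
def _compute_returns_example():
    return compute_returns([0.0, 0.0, 1.0], gamma=0.5)     # normalized length-3 tensor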
def moving_average(data, window_size):
"""Compute moving average"""
if len(data) < window_size:
return np.array([np.mean(data[:i+1]) for i in range(len(data))])
moving_avg = []
for i in range(len(data)):
if i < window_size:
moving_avg.append(np.mean(data[:i+1]))
else:
moving_avg.append(np.mean(data[i-window_size+1:i+1]))
return np.array(moving_avg)
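# A quick illustration of moving_average (sketch only): for data [1, 2, 3, 4]
# and window_size = 2, early points average everything seen so far, giving
# [1.0, 1.5, 2.5, 3.5].
def _moving_average_example():
    return moving_average([1, 2, 3, 4], window_size=2)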
# ==================== Policy Gradient Algorithm ====================
def train_policy_gradient(env_name, policy, optimizer, gamma, num_episodes,
max_steps=None, is_pong=False, action_map=None):
"""
Train policy using REINFORCE algorithm
Args:
env_name: Name of the gym environment
policy: Policy network
optimizer: PyTorch optimizer
gamma: Discount factor
num_episodes: Number of training episodes
max_steps: Maximum steps per episode (None for default)
is_pong: Whether this is Pong environment
action_map: Mapping from policy action to env action (for Pong)
"""
env = gym.make(env_name)
episode_rewards = []
for episode in range(num_episodes):
state, _ = env.reset()
# Preprocess state for Pong
if is_pong:
state = preprocess(state)
prev_frame = None # Track previous frame for motion
log_probs = []
rewards = []
done = False
step = 0
while not done:
# For Pong, use frame difference (motion signal)
if is_pong:
cur_frame = state
if prev_frame is not None:
state_input = cur_frame - prev_frame
else:
state_input = np.zeros_like(cur_frame, dtype=np.float32)
prev_frame = cur_frame
state_tensor = torch.FloatTensor(state_input).to(device)
else:
# Convert state to tensor
state_tensor = torch.FloatTensor(state).to(device)
# Get action probabilities
action_probs = policy(state_tensor)
# Sample action from the distribution
dist = Categorical(action_probs)
action = dist.sample()
log_prob = dist.log_prob(action)
# Map action for Pong (0,1 -> 2,3)
if is_pong:
env_action = action_map[action.item()]
else:
env_action = action.item()
# Take action in environment
next_state, reward, terminated, truncated, _ = env.step(env_action)
done = terminated or truncated
# Preprocess next state for Pong
if is_pong:
next_state = preprocess(next_state)
# Store log probability and reward
log_probs.append(log_prob)
rewards.append(reward)
state = next_state
step += 1
if max_steps and step >= max_steps:
break
# Compute returns
returns = compute_returns(rewards, gamma)
# Compute policy gradient loss
policy_loss = []
for log_prob, R in zip(log_probs, returns):
policy_loss.append(-log_prob * R)
# Optimize policy
optimizer.zero_grad()
policy_loss = torch.stack(policy_loss).sum()
policy_loss.backward()
# Gradient clipping for training stability
torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
optimizer.step()
# Record episode reward
episode_reward = sum(rewards)
episode_rewards.append(episode_reward)
# Print progress
if (episode + 1) % 100 == 0:
avg_reward = np.mean(episode_rewards[-100:])
print(f"Episode {episode + 1}/{num_episodes}, "
f"Avg Reward (last 100): {avg_reward:.2f}")
# Save checkpoint for Pong every 500 episodes
if is_pong and (episode + 1) % 500 == 0:
checkpoint_path = f'pong_checkpoint_ep{episode + 1}.pth'
torch.save({
'episode': episode + 1,
'policy_state_dict': policy.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'episode_rewards': episode_rewards,
}, checkpoint_path)
print(f" → Checkpoint saved: {checkpoint_path}")
env.close()
return episode_rewards
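# A hedged sketch for resuming from one of the Pong checkpoints saved above;
# it assumes a checkpoint file with the keys written by train_policy_gradient
# already exists on disk.
def resume_from_checkpoint(checkpoint_path, policy, optimizer):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    policy.load_state_dict(checkpoint['policy_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_episode = checkpoint['episode']
    episode_rewards = checkpoint['episode_rewards']
    return start_episode, episode_rewards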
def evaluate_policy(env_name, policy, num_episodes=500, is_pong=False, action_map=None):
"""Evaluate trained policy over multiple episodes"""
env = gym.make(env_name)
eval_rewards = []
for episode in range(num_episodes):
state, _ = env.reset()
if is_pong:
state = preprocess(state)
prev_frame = None # Track previous frame for motion
episode_reward = 0
done = False
while not done:
# For Pong, use frame difference (motion signal)
if is_pong:
cur_frame = state
if prev_frame is not None:
state_input = cur_frame - prev_frame
else:
state_input = np.zeros_like(cur_frame, dtype=np.float32)
prev_frame = cur_frame
state_tensor = torch.FloatTensor(state_input).to(device)
else:
state_tensor = torch.FloatTensor(state).to(device)
with torch.no_grad():
action_probs = policy(state_tensor)
action = torch.argmax(action_probs).item()
if is_pong:
env_action = action_map[action]
else:
env_action = action
next_state, reward, terminated, truncated, _ = env.step(env_action)
done = terminated or truncated
if is_pong:
next_state = preprocess(next_state)
episode_reward += reward
state = next_state
eval_rewards.append(episode_reward)
if (episode + 1) % 100 == 0:
print(f"Evaluated {episode + 1}/{num_episodes} episodes")
env.close()
return eval_rewards
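# An optional sketch for watching the greedy policy play one episode with
# on-screen rendering; render_mode="human" is a standard gymnasium option and
# is assumed to be available in your setup.
def watch_policy(env_name, policy, is_pong=False, action_map=None):
    env = gym.make(env_name, render_mode="human")
    state, _ = env.reset()
    prev_frame = None
    done = False
    while not done:
        if is_pong:
            cur_frame = preprocess(state)
            if prev_frame is not None:
                state_input = cur_frame - prev_frame
            else:
                state_input = np.zeros_like(cur_frame, dtype=np.float32)
            prev_frame = cur_frame
        else:
            state_input = state
        with torch.no_grad():
            action_probs = policy(torch.FloatTensor(state_input).to(device))
        action = torch.argmax(action_probs).item()
        env_action = action_map[action] if is_pong else action
        state, _, terminated, truncated, _ = env.step(env_action)
        done = terminated or truncated
    env.close()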
def plot_results(episode_rewards, eval_rewards, title, save_prefix):
"""Plot training curve and evaluation histogram"""
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
# Plot training curve
ax1 = axes[0]
episodes = np.arange(1, len(episode_rewards) + 1)
ma = moving_average(episode_rewards, 100)
ax1.plot(episodes, episode_rewards, alpha=0.3, label='Episode Reward')
ax1.plot(episodes, ma, linewidth=2, label='Moving Average (100 episodes)')
ax1.set_xlabel('Episode')
ax1.set_ylabel('Reward')
ax1.set_title(f'{title} - Training Curve')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot evaluation histogram
ax2 = axes[1]
mean_reward = np.mean(eval_rewards)
std_reward = np.std(eval_rewards)
ax2.hist(eval_rewards, bins=30, edgecolor='black', alpha=0.7)
ax2.axvline(mean_reward, color='red', linestyle='--', linewidth=2,
label=f'Mean: {mean_reward:.2f}')
ax2.set_xlabel('Episode Reward')
ax2.set_ylabel('Frequency')
    ax2.set_title(f'{title} - Evaluation Histogram ({len(eval_rewards)} episodes)\n'
                  f'Mean: {mean_reward:.2f}, Std: {std_reward:.2f}')
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig(f'{save_prefix}_results.png', dpi=150, bbox_inches='tight')
plt.show()
print(f"\n{title} Evaluation Results:")
print(f"Mean Reward: {mean_reward:.2f}")
print(f"Std Reward: {std_reward:.2f}")
# ==================== Main Training Scripts ====================
def train_cartpole():
"""Train CartPole-v1"""
print("\n" + "="*60)
print("Training CartPole-v1")
print("="*60 + "\n")
# Environment parameters
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
env.close()
# Hyperparameters
gamma = 0.95
learning_rate = 0.01
num_episodes = 1000
# Initialize policy and optimizer
policy = CartPolePolicy(state_dim, action_dim).to(device)
optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
# Train
episode_rewards = train_policy_gradient(
'CartPole-v1', policy, optimizer, gamma, num_episodes
)
# Evaluate
print("\nEvaluating trained policy...")
eval_rewards = evaluate_policy('CartPole-v1', policy, num_episodes=500)
# Plot results
plot_results(episode_rewards, eval_rewards, 'CartPole-v1', 'cartpole')
# Save model
torch.save(policy.state_dict(), 'cartpole_policy.pth')
print("\nModel saved as 'cartpole_policy.pth'")
return policy, episode_rewards, eval_rewards
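# A sketch for reloading the saved CartPole weights later for inference;
# assumes 'cartpole_policy.pth' (written above) exists, and uses CartPole-v1's
# known dimensions (4-dim state, 2 actions).
def load_cartpole_policy(path='cartpole_policy.pth'):
    policy = CartPolePolicy(state_dim=4, action_dim=2).to(device)
    policy.load_state_dict(torch.load(path, map_location=device))
    policy.eval()
    return policy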
def train_pong():
"""Train Pong-v5"""
print("\n" + "="*60)
print("Training Pong-v5")
print("="*60 + "\n")
# Hyperparameters
gamma = 0.99
learning_rate = 0.001 # Lower learning rate for stability
    num_episodes = 1000  # Pong typically needs many more episodes than this; kept modest here
# Action mapping: policy outputs 0 or 1, map to RIGHT(2) or LEFT(3)
action_map = [2, 3] # Index 0->RIGHT(2), Index 1->LEFT(3)
# Initialize policy and optimizer
policy = PongPolicy(action_dim=2).to(device)
optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
print(f"Using learning rate: {learning_rate} (reduced for stability)")
print(f"Action mapping: 0->RIGHT(2), 1->LEFT(3)")
print(f"Gradient clipping: max_norm=1.0")
print(f"Weight initialization: Kaiming (Conv) + Xavier (FC)\n")
# Train with periodic checkpointing
print("Starting training (checkpoints saved every 500 episodes)...\n")
episode_rewards = train_policy_gradient(
'ALE/Pong-v5', policy, optimizer, gamma, num_episodes,
is_pong=True, action_map=action_map
)
print("\nTraining completed!")
# Evaluate
print("\nEvaluating trained policy...")
eval_rewards = evaluate_policy(
'ALE/Pong-v5', policy, num_episodes=500,
is_pong=True, action_map=action_map
)
# Plot results
plot_results(episode_rewards, eval_rewards, 'Pong-v5', 'pong')
# Save model
torch.save(policy.state_dict(), 'pong_policy.pth')
print("\nModel saved as 'pong_policy.pth'")
return policy, episode_rewards, eval_rewards
# ==================== Run Training ====================
if __name__ == "__main__":
    # CartPole training is disabled by default; uncomment the next line to run it.
    #cartpole_policy, cartpole_train_rewards, cartpole_eval_rewards = train_cartpole()
    # Train Pong. Note: this takes significantly longer (possibly hours);
    # reduce num_episodes in train_pong() if you are just testing the code.
    pong_policy, pong_train_rewards, pong_eval_rewards = train_pong()