kevinwang676 commited on
Commit
7101384
·
verified ·
1 Parent(s): 41a111d

Create test2.py

Browse files
Files changed (1) hide show
  1. test2.py +449 -0
test2.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torch.nn.functional as F
5
+ from torch.distributions import Categorical
6
+ import numpy as np
7
+ import ale_py
8
+ import gymnasium as gym
9
+ import matplotlib.pyplot as plt
10
+ from collections import deque
11
+
12
# Register the ALE environments with Gymnasium before any environment is created.
gym.register_envs(ale_py)

# Seed the global PyTorch and NumPy random number generators for reproducibility.
# NOTE(review): env.reset() is called without a seed elsewhere in this file, so
# environment-side randomness is presumably not covered by these seeds -- confirm.
torch.manual_seed(42)
np.random.seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a module-level global read by the training and evaluation code.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
22
+
23
+
24
+ # ==================== Policy Networks ====================
25
+
26
class CartPolePolicy(nn.Module):
    """Two-layer MLP mapping a CartPole observation to action probabilities.

    Args:
        state_dim: Length of the observation vector.
        action_dim: Number of discrete actions.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        """Give every linear layer Xavier-uniform weights and zero biases."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return action probabilities (softmax over the last dimension)."""
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return torch.softmax(logits, dim=-1)
47
+
48
+
49
class PongPolicy(nn.Module):
    """Convolutional policy for Pong operating on single-channel 80x80 frames.

    Args:
        action_dim: Number of discrete actions the policy chooses between.
    """

    def __init__(self, action_dim=2):
        super().__init__()
        # Feature extractor: the spatial size shrinks 80 -> 19 -> 8.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)

        # Classifier head on the flattened 32 x 8 x 8 feature map.
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, action_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform for conv weights, Xavier-uniform for linear weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return action probabilities of shape (batch, action_dim).

        A 2-D input is treated as one frame (batch and channel axes are added);
        a 3-D input is treated as a batch of frames (a channel axis is added);
        any other rank is passed to the first convolution unchanged.
        """
        rank = x.dim()
        if rank == 2:
            x = x[None, None]
        elif rank == 3:
            x = x[:, None]

        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return F.softmax(self.fc2(hidden), dim=-1)
88
+
89
+
90
+ # ==================== Helper Functions ====================
91
+
92
def preprocess(image):
    """Convert a 210x160x3 uint8 Pong frame into an 80x80 binary float array.

    The frame is cropped to rows 35..194, downsampled by a factor of 2 on both
    axes and reduced to its first colour channel. Pixels equal to 0, 144 or 109
    (the two background values) become 0.0; every other pixel (paddles, ball)
    becomes 1.0.

    The input array is NOT modified. The crop/downsample slice is a NumPy view
    of ``image``, so the mask is built with comparisons and ``np.where`` rather
    than by assigning into the slice, which would overwrite the caller's frame.

    Args:
        image: Frame array of shape (210, 160, 3).

    Returns:
        A new float64 array of shape (80, 80) containing only 0.0 and 1.0.
    """
    # Crop + downsample + channel select in a single (read-only use of a) view.
    frame = image[35:195:2, ::2, 0]
    is_background = (frame == 0) | (frame == 144) | (frame == 109)
    # reshape keeps the original's guarantee of failing loudly on wrong-sized input.
    return np.where(is_background, 0.0, 1.0).reshape(80, 80)
100
+
101
+
102
def compute_returns(rewards, gamma):
    """Compute the (normalised) discounted return for every timestep.

    For rewards r_0..r_{T-1} the return at step t is
    G_t = r_t + gamma * G_{t+1}, with G_T = 0.

    Args:
        rewards: Sequence of per-step rewards for one episode.
        gamma: Discount factor.

    Returns:
        A float32 tensor of length ``len(rewards)`` on the module-level
        ``device``. When there is more than one step the returns are
        standardised to zero mean and unit standard deviation.
    """
    # Accumulate back-to-front with append and reverse once at the end;
    # inserting at index 0 on every step would be quadratic in episode length.
    discounted = []
    running = 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        discounted.append(running)
    discounted.reverse()

    returns = torch.tensor(discounted, dtype=torch.float32).to(device)

    # Normalise for more stable training. A single-element tensor is skipped
    # because its (unbiased) standard deviation is undefined.
    if len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns
114
+
115
+
116
def moving_average(data, window_size):
    """Return the trailing moving average of ``data`` as a NumPy array.

    Element ``i`` of the result is the mean of the last ``window_size`` values
    ending at ``i``; near the start, where fewer values exist, the mean of all
    values seen so far is used instead.
    """
    averaged = [
        np.mean(data[max(0, end - window_size):end])
        for end in range(1, len(data) + 1)
    ]
    return np.array(averaged)
128
+
129
+
130
+ # ==================== Policy Gradient Algorithm ====================
131
+
132
def train_policy_gradient(env_name, policy, optimizer, gamma, num_episodes,
                          max_steps=None, is_pong=False, action_map=None):
    """Train a policy with the REINFORCE (Monte-Carlo policy gradient) algorithm.

    One gradient step is taken per episode, using the normalised discounted
    returns from ``compute_returns`` as weights on the log-probabilities of the
    sampled actions. Gradients are clipped to a global norm of 1.0.

    Side effects: prints the 100-episode average reward every 100 episodes and,
    for Pong, writes ``pong_checkpoint_ep<N>.pth`` every 500 episodes.

    Args:
        env_name: Name of the gym environment.
        policy: Policy network returning action probabilities.
        optimizer: PyTorch optimizer over ``policy``'s parameters.
        gamma: Discount factor.
        num_episodes: Number of training episodes.
        max_steps: Optional cap on steps per episode (``None`` or 0 = no cap).
        is_pong: Whether frames need Pong preprocessing / frame differencing.
        action_map: Sequence mapping policy action index to env action
            (required when ``is_pong`` is true).

    Returns:
        List with the undiscounted total reward of every episode.

    Raises:
        ValueError: If ``is_pong`` is true but ``action_map`` is ``None``.
    """
    if is_pong and action_map is None:
        # Fail before creating the environment instead of with an opaque
        # "'NoneType' object is not subscriptable" on the first step.
        raise ValueError("action_map is required when is_pong=True")

    env = gym.make(env_name)
    episode_rewards = []

    # try/finally guarantees the environment is released even when training
    # is interrupted or an exception is raised mid-episode.
    try:
        for episode in range(num_episodes):
            state, _ = env.reset()

            # Preprocess state for Pong
            if is_pong:
                state = preprocess(state)
                prev_frame = None  # Track previous frame for motion

            log_probs = []
            rewards = []

            done = False
            step = 0

            while not done:
                if is_pong:
                    # The network sees the frame difference (motion signal);
                    # the very first step has no predecessor and gets zeros.
                    cur_frame = state
                    if prev_frame is not None:
                        state_input = cur_frame - prev_frame
                    else:
                        state_input = np.zeros_like(cur_frame, dtype=np.float32)
                    prev_frame = cur_frame
                    state_tensor = torch.tensor(
                        state_input, dtype=torch.float32).to(device)
                else:
                    state_tensor = torch.tensor(
                        state, dtype=torch.float32).to(device)

                # Sample an action from the policy's categorical distribution
                action_probs = policy(state_tensor)
                dist = Categorical(action_probs)
                action = dist.sample()
                log_prob = dist.log_prob(action)

                # Map the policy's action index to the env action for Pong
                if is_pong:
                    env_action = action_map[action.item()]
                else:
                    env_action = action.item()

                next_state, reward, terminated, truncated, _ = env.step(env_action)
                done = terminated or truncated

                if is_pong:
                    next_state = preprocess(next_state)

                log_probs.append(log_prob)
                rewards.append(reward)

                state = next_state
                step += 1

                if max_steps and step >= max_steps:
                    break

            returns = compute_returns(rewards, gamma)

            # Each log-prob is a scalar (CartPole) or has shape (1,) (Pong, which
            # adds a batch axis); flatten so it lines up 1:1 with `returns`
            # instead of broadcasting to a (T, T) matrix.
            stacked_log_probs = torch.stack(log_probs).reshape(-1)
            policy_loss = -(stacked_log_probs * returns).sum()

            optimizer.zero_grad()
            policy_loss.backward()
            # Gradient clipping for training stability
            torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
            optimizer.step()

            episode_rewards.append(sum(rewards))

            # Print progress
            if (episode + 1) % 100 == 0:
                avg_reward = np.mean(episode_rewards[-100:])
                print(f"Episode {episode + 1}/{num_episodes}, "
                      f"Avg Reward (last 100): {avg_reward:.2f}")

            # Save checkpoint for Pong every 500 episodes
            if is_pong and (episode + 1) % 500 == 0:
                checkpoint_path = f'pong_checkpoint_ep{episode + 1}.pth'
                torch.save({
                    'episode': episode + 1,
                    'policy_state_dict': policy.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'episode_rewards': episode_rewards,
                }, checkpoint_path)
                print(f" → Checkpoint saved: {checkpoint_path}")
    finally:
        env.close()

    return episode_rewards
249
+
250
+
251
def evaluate_policy(env_name, policy, num_episodes=500, is_pong=False, action_map=None):
    """Evaluate a trained policy with greedy (argmax) action selection.

    Side effects: prints a progress line every 100 episodes.

    Args:
        env_name: Name of the gym environment.
        policy: Policy network returning action probabilities.
        num_episodes: Number of evaluation episodes to run.
        is_pong: Whether frames need Pong preprocessing / frame differencing.
        action_map: Sequence mapping policy action index to env action
            (required when ``is_pong`` is true).

    Returns:
        List with the total reward of every evaluation episode.

    Raises:
        ValueError: If ``is_pong`` is true but ``action_map`` is ``None``.
    """
    if is_pong and action_map is None:
        # Fail before creating the environment instead of with an opaque
        # "'NoneType' object is not subscriptable" on the first step.
        raise ValueError("action_map is required when is_pong=True")

    env = gym.make(env_name)
    eval_rewards = []

    # try/finally guarantees the environment is released even when evaluation
    # is interrupted or an exception is raised mid-episode.
    try:
        for episode in range(num_episodes):
            state, _ = env.reset()

            if is_pong:
                state = preprocess(state)
                prev_frame = None  # Track previous frame for motion

            episode_reward = 0
            done = False

            while not done:
                if is_pong:
                    # The network sees the frame difference (motion signal);
                    # the very first step has no predecessor and gets zeros.
                    cur_frame = state
                    if prev_frame is not None:
                        state_input = cur_frame - prev_frame
                    else:
                        state_input = np.zeros_like(cur_frame, dtype=np.float32)
                    prev_frame = cur_frame
                    state_tensor = torch.tensor(
                        state_input, dtype=torch.float32).to(device)
                else:
                    state_tensor = torch.tensor(
                        state, dtype=torch.float32).to(device)

                # No gradients are needed; pick the most probable action.
                with torch.no_grad():
                    action_probs = policy(state_tensor)
                    action = torch.argmax(action_probs).item()

                env_action = action_map[action] if is_pong else action

                next_state, reward, terminated, truncated, _ = env.step(env_action)
                done = terminated or truncated

                if is_pong:
                    next_state = preprocess(next_state)

                episode_reward += reward
                state = next_state

            eval_rewards.append(episode_reward)

            if (episode + 1) % 100 == 0:
                print(f"Evaluated {episode + 1}/{num_episodes} episodes")
    finally:
        env.close()

    return eval_rewards
304
+
305
+
306
def plot_results(episode_rewards, eval_rewards, title, save_prefix):
    """Plot the training curve and the evaluation-reward histogram side by side.

    The figure is written to ``<save_prefix>_results.png`` and then shown, and
    the evaluation mean / standard deviation are printed.

    Args:
        episode_rewards: Per-episode rewards collected during training.
        eval_rewards: Per-episode rewards collected during evaluation.
        title: Label used in the plot titles and the printed summary.
        save_prefix: Filename prefix for the saved PNG.
    """
    fig, (train_ax, eval_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Training curve: raw episode rewards plus a 100-episode moving average.
    episodes = np.arange(1, len(episode_rewards) + 1)
    smoothed = moving_average(episode_rewards, 100)

    train_ax.plot(episodes, episode_rewards, alpha=0.3, label='Episode Reward')
    train_ax.plot(episodes, smoothed, linewidth=2, label='Moving Average (100 episodes)')
    train_ax.set_xlabel('Episode')
    train_ax.set_ylabel('Reward')
    train_ax.set_title(f'{title} - Training Curve')
    train_ax.legend()
    train_ax.grid(True, alpha=0.3)

    # Evaluation histogram with the mean marked.
    mean_reward = np.mean(eval_rewards)
    std_reward = np.std(eval_rewards)

    eval_ax.hist(eval_rewards, bins=30, edgecolor='black', alpha=0.7)
    eval_ax.axvline(mean_reward, color='red', linestyle='--', linewidth=2,
                    label=f'Mean: {mean_reward:.2f}')
    eval_ax.set_xlabel('Episode Reward')
    eval_ax.set_ylabel('Frequency')
    # Report the actual number of evaluation episodes rather than a fixed
    # "500", which is wrong whenever a different count was evaluated.
    eval_ax.set_title(f'{title} - Evaluation Histogram ({len(eval_rewards)} episodes)\n'
                      f'Mean: {mean_reward:.2f}, Std: {std_reward:.2f}')
    eval_ax.legend()
    eval_ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f'{save_prefix}_results.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n{title} Evaluation Results:")
    print(f"Mean Reward: {mean_reward:.2f}")
    print(f"Std Reward: {std_reward:.2f}")
345
+
346
+
347
+ # ==================== Main Training Scripts ====================
348
+
349
def train_cartpole():
    """Train, evaluate, plot and save a REINFORCE policy for CartPole-v1.

    Returns:
        Tuple ``(policy, episode_rewards, eval_rewards)``.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Training CartPole-v1")
    print(banner + "\n")

    env_id = 'CartPole-v1'

    # A throwaway environment is created only to read the space dimensions.
    probe_env = gym.make(env_id)
    obs_size = probe_env.observation_space.shape[0]
    n_actions = probe_env.action_space.n
    probe_env.close()

    # Hyperparameters
    discount = 0.95
    lr = 0.01
    n_train_episodes = 1000

    policy = CartPolePolicy(obs_size, n_actions).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=lr)

    episode_rewards = train_policy_gradient(
        env_id, policy, optimizer, discount, n_train_episodes
    )

    print("\nEvaluating trained policy...")
    eval_rewards = evaluate_policy(env_id, policy, num_episodes=500)

    plot_results(episode_rewards, eval_rewards, 'CartPole-v1', 'cartpole')

    weights_path = 'cartpole_policy.pth'
    torch.save(policy.state_dict(), weights_path)
    print("\nModel saved as 'cartpole_policy.pth'")

    return policy, episode_rewards, eval_rewards
387
+
388
+
389
def train_pong():
    """Train, evaluate, plot and save a REINFORCE policy for ALE/Pong-v5.

    Returns:
        Tuple ``(policy, episode_rewards, eval_rewards)``.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Training Pong-v5")
    print(banner + "\n")

    env_id = 'ALE/Pong-v5'

    # Hyperparameters
    discount = 0.99
    learning_rate = 0.001
    n_train_episodes = 1000

    # The policy picks index 0 or 1; the environment receives action 2 or 3.
    action_map = [2, 3]

    policy = PongPolicy(action_dim=2).to(device)
    optimizer = optim.Adam(policy.parameters(), lr=learning_rate)

    print(f"Using learning rate: {learning_rate} (reduced for stability)")
    print("Action mapping: 0->RIGHT(2), 1->LEFT(3)")
    print("Gradient clipping: max_norm=1.0")
    print("Weight initialization: Kaiming (Conv) + Xavier (FC)\n")

    print("Starting training (checkpoints saved every 500 episodes)...\n")
    episode_rewards = train_policy_gradient(
        env_id, policy, optimizer, discount, n_train_episodes,
        is_pong=True, action_map=action_map,
    )
    print("\nTraining completed!")

    print("\nEvaluating trained policy...")
    eval_rewards = evaluate_policy(
        env_id, policy, num_episodes=500,
        is_pong=True, action_map=action_map,
    )

    plot_results(episode_rewards, eval_rewards, 'Pong-v5', 'pong')

    weights_path = 'pong_policy.pth'
    torch.save(policy.state_dict(), weights_path)
    print("\nModel saved as 'pong_policy.pth'")

    return policy, episode_rewards, eval_rewards
436
+
437
+
438
+ # ==================== Run Training ====================
439
+
440
if __name__ == "__main__":
    # CartPole training is currently disabled; uncomment the next line to run it.
    #cartpole_policy, cartpole_train_rewards, cartpole_eval_rewards = train_cartpole()

    # Pong training is enabled and runs by default. It takes significantly
    # longer than CartPole (possibly hours); reduce num_episodes inside
    # train_pong() when only testing the code.
    #print("\n\nNote: Pong training will take significantly longer (may take hours)")
    #print("You may want to reduce num_episodes if just testing the code.\n")

    pong_policy, pong_train_rewards, pong_eval_rewards = train_pong()