import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np


class PolicyNetwork(nn.Module):
    """MLP policy head: maps a state vector to unnormalized action logits."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        # Xavier-initialize all linear layers with zero biases
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0)

    def forward(self, state):
        return self.net(state)


class ValueNetwork(nn.Module):
    """MLP value head: maps a state vector to a scalar state-value estimate."""

    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        # Xavier-initialize all linear layers with zero biases
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0)

    def forward(self, state):
        return self.net(state)


class PPOAgent:
    """PPO agent with separate policy and value networks and a shared Adam optimizer."""

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, clip_eps=0.2):
        self.action_dim = action_dim
        self.policy = PolicyNetwork(state_dim, action_dim)
        self.value = ValueNetwork(state_dim)
        self.optimizer = optim.Adam(
            list(self.policy.parameters()) + list(self.value.parameters()),
            lr=lr
        )
        self.memory = []
        self.gamma = gamma
        self.clip_eps = clip_eps
        # Note: these epsilon attributes are not used by the PPO update below;
        # exploration comes from the entropy bonus in the loss.
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.1
        self.mini_batch_size = 256
        self.update_epochs = 4
    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        # Input validation: fall back to a uniform random action on NaN/Inf states
        if torch.isnan(state).any() or torch.isinf(state).any():
            print(f"Warning: Invalid state detected: {state}")
            action = np.random.randint(0, self.action_dim)
            log_prob = torch.log(torch.tensor(1.0 / self.action_dim))
            return action, log_prob
        with torch.no_grad():
            logits = self.policy(state)
            if torch.isnan(logits).any() or torch.isinf(logits).any():
                print(f"Warning: Invalid logits detected: {logits}")
                action = np.random.randint(0, self.action_dim)
                log_prob = torch.log(torch.tensor(1.0 / self.action_dim))
                return action, log_prob
            # Temperature scaling (temperature = 1.0, i.e. currently a no-op)
            logits = logits / 1.0
            probs = torch.softmax(logits, dim=-1)
            # Add a small epsilon to prevent zero probabilities, then renormalize
            probs = probs + 1e-8
            probs = probs / probs.sum(dim=-1, keepdim=True)
            dist = Categorical(probs)
            action = dist.sample()
            # Squeeze the batch dimension so stored log-probs are scalar tensors,
            # matching the scalar fallback log-probs above
            return action.item(), dist.log_prob(action).squeeze(0)
    def store_experience(self, state, action, reward, next_state, done, log_prob):
        # Validate inputs before storing
        if not (torch.isnan(torch.FloatTensor(state)).any() or
                torch.isnan(torch.FloatTensor(next_state)).any() or
                torch.isnan(log_prob).any() or
                np.isnan(reward)):
            self.memory.append((state, action, reward, next_state, done, log_prob))
    def calculate_returns_and_advantages_optimized(self, rewards, values, dones):
        """
        O(n) calculation of returns and advantages using Generalized Advantage
        Estimation (GAE), iterating backwards over the trajectory.
        """
        returns = torch.zeros_like(rewards)
        advantages = torch.zeros_like(rewards)
        # GAE smoothing parameter
        gae_lambda = 0.95
        # Calculate advantages and returns in reverse order (O(n))
        last_gae = 0
        for i in reversed(range(len(rewards))):
            if i == len(rewards) - 1:
                next_value = 0  # No bootstrap value beyond the end of the buffer
            else:
                next_value = values[i + 1]
            # Mask the bootstrap term and the running GAE at episode boundaries
            mask = 1.0 - dones[i]
            # TD error: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            delta = rewards[i] + self.gamma * next_value * mask - values[i]
            # GAE recursion: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
            advantages[i] = last_gae = delta + self.gamma * gae_lambda * mask * last_gae
            # Returns are the value targets: R_t = A_t + V(s_t)
            returns[i] = advantages[i] + values[i]
        return returns, advantages
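
    # Illustrative hand-check of the GAE recursion above. The numbers are assumed
    # for this example only (gamma = 0.99, gae_lambda = 0.95) with a two-step
    # episode: rewards = [1.0, 1.0], values = [0.5, 0.4], dones = [0, 1].
    #   i = 1: delta_1 = 1.0 + 0 - 0.4 = 0.6
    #          A_1 = 0.6, R_1 = 0.6 + 0.4 = 1.0
    #   i = 0: delta_0 = 1.0 + 0.99 * 0.4 - 0.5 = 0.896
    #          A_0 = 0.896 + 0.99 * 0.95 * 0.6 = 1.4603, R_0 = 1.4603 + 0.5 = 1.9603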
    def update(self, epochs=None):
        if len(self.memory) < 32:  # Minimum batch size
            return None
        # Use the instance default if epochs is not specified
        if epochs is None:
            epochs = self.update_epochs
        # Convert experience buffer to tensors
        states, actions, rewards, next_states, dones, old_log_probs = zip(*self.memory)
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        dones = torch.FloatTensor(dones)
        old_log_probs = torch.stack(old_log_probs)
        # Validate tensors
        if (torch.isnan(states).any() or torch.isnan(rewards).any() or
                torch.isnan(old_log_probs).any()):
            print("Warning: NaN detected in experience buffer, skipping update")
            self.memory.clear()
            return None
        # Get values for all states in a single forward pass
        with torch.no_grad():
            values = self.value(states).squeeze(-1)
        # O(n) returns and advantages via GAE, using the stored done flags
        returns, advantages = self.calculate_returns_and_advantages_optimized(
            rewards, values, dones
        )
        # Normalize advantages
        if len(advantages) > 1 and advantages.std() > 1e-8:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        # Mini-batch training for large datasets
        dataset_size = len(states)
        indices = torch.randperm(dataset_size)
        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        update_count = 0
        for epoch in range(epochs):
            # Process in mini-batches to avoid memory issues
            for start_idx in range(0, dataset_size, self.mini_batch_size):
                end_idx = min(start_idx + self.mini_batch_size, dataset_size)
                batch_indices = indices[start_idx:end_idx]
                # Slice the mini-batch
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                # Forward pass through the current policy
                logits = self.policy(batch_states)
                probs = torch.softmax(logits, dim=-1) + 1e-8
                probs = probs / probs.sum(dim=-1, keepdim=True)
                dist = Categorical(probs)
                new_log_probs = dist.log_prob(batch_actions)
                # Clipped-surrogate policy loss
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                # Extra clamp on the raw ratio guards against numerical blow-ups
                ratio = torch.clamp(ratio, 0.1, 10.0)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                # Value loss
                current_values = self.value(batch_states).squeeze(-1)
                value_loss = nn.MSELoss()(current_values, batch_returns)
                # Entropy bonus encourages exploration
                entropy = dist.entropy().mean()
                # Total loss with a small entropy coefficient
                total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
                # Skip batches that produce NaN/Inf losses
                if torch.isnan(total_loss) or torch.isinf(total_loss):
                    print("Warning: NaN/Inf loss detected, skipping batch")
                    continue
                # Optimize with gradient clipping
                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(self.value.parameters(), 0.5)
                self.optimizer.step()
                # Accumulate running loss statistics
                total_policy_loss += policy_loss.detach().item()
                total_value_loss += value_loss.detach().item()
                total_entropy += entropy.detach().item()
                update_count += 1
        # Clear the on-policy buffer after the update
        self.memory.clear()
        if update_count == 0:
            return None
        return {
            'policy_loss': total_policy_loss / update_count,
            'value_loss': total_value_loss / update_count,
            'entropy': total_entropy / update_count
        }
    def save(self, filepath):
        """Save the model"""
        torch.save({
            'policy_state_dict': self.policy.state_dict(),
            'value_state_dict': self.value.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, filepath)

    def load(self, filepath):
        """Load the model"""
        checkpoint = torch.load(filepath, map_location='cpu')
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.value.load_state_dict(checkpoint['value_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
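

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test of the
# collect / update / save cycle on synthetic random data. The state and action
# dimensions (8 and 8), the trajectory length, and the checkpoint file name are
# illustrative assumptions, not values taken from the original code.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    state_dim, action_dim = 8, 8  # assumed dimensions for the demo only
    agent = PPOAgent(state_dim, action_dim)

    # Collect a short synthetic trajectory: random states and rewards stand in
    # for a real environment's observations and reward signal.
    state = np.random.randn(state_dim).astype(np.float32)
    for step in range(64):
        action, log_prob = agent.select_action(state)
        next_state = np.random.randn(state_dim).astype(np.float32)
        reward = float(np.random.randn())
        done = step == 63  # mark the final transition as terminal
        agent.store_experience(state, action, reward, next_state, done, log_prob)
        state = next_state

    # Run one PPO update over the collected buffer and report the mean losses.
    stats = agent.update()
    print("update stats:", stats)

    # Checkpointing round-trip.
    agent.save("ppo_checkpoint.pt")
    agent.load("ppo_checkpoint.pt")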