# RL_Car_Agent/ppo_agent_gradio.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, action_dim)
)
for layer in self.net:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
nn.init.constant_(layer.bias, 0)
def forward(self, state):
return self.net(state)
class ValueNetwork(nn.Module):
def __init__(self, state_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 1)
)
for layer in self.net:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
nn.init.constant_(layer.bias, 0)
def forward(self, state):
return self.net(state)
class PPOAgent:
def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, clip_eps=0.2):
        self.action_dim = action_dim
        self.policy = PolicyNetwork(state_dim, action_dim)
self.value = ValueNetwork(state_dim)
self.optimizer = optim.Adam(
list(self.policy.parameters()) + list(self.value.parameters()),
lr=lr
)
self.memory = []
self.gamma = gamma
self.clip_eps = clip_eps
        self.epsilon = 1.0  # exploration-schedule fields; not used inside this class
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.1
self.mini_batch_size = 256
self.update_epochs = 4
    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        # Validate the input state before querying the policy
        if torch.isnan(state).any() or torch.isinf(state).any():
            print(f"Warning: Invalid state detected: {state}")
            action = np.random.randint(0, self.action_dim)
            log_prob = torch.log(torch.tensor(1.0 / self.action_dim))
            return action, log_prob
        with torch.no_grad():
            logits = self.policy(state)
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            print(f"Warning: Invalid logits detected: {logits}")
            action = np.random.randint(0, self.action_dim)
            log_prob = torch.log(torch.tensor(1.0 / self.action_dim))
            return action, log_prob
        # Temperature scaling (T = 1.0 leaves the logits unchanged)
        logits = logits / 1.0
        probs = torch.softmax(logits, dim=-1)
        # Add a small epsilon to avoid zero probabilities, then renormalize
        probs = probs + 1e-8
        probs = probs / probs.sum(dim=-1, keepdim=True)
        dist = Categorical(probs)
        action = dist.sample()
        # Squeeze the batch dimension so the stored log-prob is a scalar tensor
        return action.item(), dist.log_prob(action).squeeze(0)
def store_experience(self, state, action, reward, next_state, done, log_prob):
# Validate inputs before storing
if not (torch.isnan(torch.FloatTensor(state)).any() or
torch.isnan(torch.FloatTensor(next_state)).any() or
torch.isnan(log_prob).any() or
np.isnan(reward)):
self.memory.append((state, action, reward, next_state, done, log_prob))
    def calculate_returns_and_advantages_optimized(self, rewards, values, dones):
        """
        O(n) computation of returns and advantages using Generalized
        Advantage Estimation (GAE), iterating the trajectory in reverse.
        """
        returns = torch.zeros_like(rewards)
        advantages = torch.zeros_like(rewards)
        gae_lambda = 0.95  # GAE smoothing parameter
        last_gae = 0.0
        for i in reversed(range(len(rewards))):
            if i == len(rewards) - 1:
                next_value = 0.0  # no bootstrap value beyond the buffer
            else:
                next_value = values[i + 1]
            # Mask the bootstrap and accumulated advantage at episode boundaries
            mask = 1.0 - float(dones[i])
            # TD error
            delta = rewards[i] + self.gamma * next_value * mask - values[i]
            # GAE advantage (recursive form)
            advantages[i] = last_gae = delta + self.gamma * gae_lambda * mask * last_gae
            # Returns are advantages plus the value baseline
            returns[i] = advantages[i] + values[i]
        return returns, advantages
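    # Illustrative check of the GAE recursion above (inputs assumed, not taken
    # from the project): with gamma=0.99, gae_lambda=0.95, rewards=[1.0, 1.0],
    # values=[0.5, 0.5], dones=[0, 0]:
    #   i=1: delta = 1.0 - 0.5 = 0.5              -> adv = 0.5,     return = 1.0
    #   i=0: delta = 1.0 + 0.99*0.5 - 0.5 = 0.995 -> adv = 0.995 + 0.99*0.95*0.5
    #        = 1.46525,                              return = 1.96525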
def update(self, epochs=None):
if len(self.memory) < 32: # Minimum batch size
return None
# Use instance variable if epochs not specified
if epochs is None:
epochs = self.update_epochs
        # Convert the experience buffer to tensors
        states, actions, rewards, next_states, dones, old_log_probs = zip(*self.memory)
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        dones = torch.FloatTensor(dones)
        old_log_probs = torch.stack(old_log_probs)
# Validate tensors
if (torch.isnan(states).any() or torch.isnan(rewards).any() or
torch.isnan(old_log_probs).any()):
print("Warning: NaN detected in experience buffer, skipping update")
self.memory.clear()
return None
# Get values for all states at once
with torch.no_grad():
values = self.value(states).squeeze(-1)
        # O(n) returns and advantages via GAE, masking across episode boundaries
        returns, advantages = self.calculate_returns_and_advantages_optimized(
            rewards, values, dones
        )
# Normalize advantages
if len(advantages) > 1 and advantages.std() > 1e-8:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Mini-batch training for large datasets
dataset_size = len(states)
indices = torch.randperm(dataset_size)
total_policy_loss = 0
total_value_loss = 0
total_entropy = 0
update_count = 0
for epoch in range(epochs):
# Process in mini-batches to avoid memory issues
for start_idx in range(0, dataset_size, self.mini_batch_size):
end_idx = min(start_idx + self.mini_batch_size, dataset_size)
batch_indices = indices[start_idx:end_idx]
# Get mini-batch
batch_states = states[batch_indices]
batch_actions = actions[batch_indices]
batch_advantages = advantages[batch_indices]
batch_returns = returns[batch_indices]
batch_old_log_probs = old_log_probs[batch_indices]
# Forward pass
logits = self.policy(batch_states)
probs = torch.softmax(logits, dim=-1) + 1e-8
probs = probs / probs.sum(dim=-1, keepdim=True)
dist = Categorical(probs)
new_log_probs = dist.log_prob(batch_actions)
# Policy loss
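                # PPO clipped surrogate: L = E[min(r_t * A_t, clip(r_t, 1-eps, 1+eps) * A_t)],
                # where r_t = pi_new(a_t|s_t) / pi_old(a_t|s_t), computed in log space below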
ratio = torch.exp(new_log_probs - batch_old_log_probs)
ratio = torch.clamp(ratio, 0.1, 10.0)
surr1 = ratio * batch_advantages
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * batch_advantages
policy_loss = -torch.min(surr1, surr2).mean()
# Value loss
current_values = self.value(batch_states).squeeze(-1)
value_loss = nn.MSELoss()(current_values, batch_returns)
# Entropy
entropy = dist.entropy().mean()
                # Combined objective: clipped policy loss + value loss - entropy bonus
                total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
# Check for NaN
if torch.isnan(total_loss) or torch.isinf(total_loss):
print("Warning: NaN/Inf loss detected, skipping batch")
continue
# Optimize
self.optimizer.zero_grad()
total_loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
torch.nn.utils.clip_grad_norm_(self.value.parameters(), 0.5)
self.optimizer.step()
# Accumulate losses
total_policy_loss += policy_loss.detach().item()
total_value_loss += value_loss.detach().item()
total_entropy += entropy.detach().item()
update_count += 1
# Clear memory
self.memory.clear()
if update_count == 0:
return None
return {
'policy_loss': total_policy_loss / update_count,
'value_loss': total_value_loss / update_count,
'entropy': total_entropy / update_count
}
def save(self, filepath):
"""Save the model"""
torch.save({
'policy_state_dict': self.policy.state_dict(),
'value_state_dict': self.value.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict()
}, filepath)
def load(self, filepath):
"""Load the model"""
checkpoint = torch.load(filepath, map_location='cpu')
self.policy.load_state_dict(checkpoint['policy_state_dict'])
self.value.load_state_dict(checkpoint['value_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
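

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original Gradio app). It wires
# PPOAgent to a stand-in random environment so the training loop can run end
# to end. `DummyEnv`, its dimensions, and the episode/update schedule are
# assumptions for demonstration only; the real project drives the agent with
# its own car environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class DummyEnv:
        """Minimal stand-in environment: random 8-dim states, 8 discrete actions."""
        def __init__(self, state_dim=8):
            self.state_dim = state_dim

        def reset(self):
            return np.random.randn(self.state_dim).astype(np.float32)

        def step(self, action):
            next_state = np.random.randn(self.state_dim).astype(np.float32)
            reward = float(np.random.rand())
            done = bool(np.random.rand() < 0.05)
            return next_state, reward, done

    env = DummyEnv()
    agent = PPOAgent(state_dim=env.state_dim, action_dim=8)

    for episode in range(5):
        state = env.reset()
        done = False
        while not done:
            action, log_prob = agent.select_action(state)
            next_state, reward, done = env.step(action)
            agent.store_experience(state, action, reward, next_state, done, log_prob)
            state = next_state
        stats = agent.update()  # returns None until enough experience is buffered
        if stats is not None:
            print(f"episode {episode}: {stats}")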