HungNP
New single commit message
cb80c28
raw
history blame contribute delete
No virus
6.99 kB
import logging
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
class PolicyNet(torch.nn.Module):
    """Deterministic policy: one ReLU hidden layer, tanh-squashed output.

    The output of the final linear layer is passed through ``tanh`` and then
    multiplied by ``action_bound``.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        """Build the two linear layers and remember the output scale.

        Args:
            state_dim: Width of the state vector fed to the network.
            hidden_dim: Width of the hidden layer.
            action_dim: Width of the produced action vector.
            action_bound: Factor applied to the tanh-squashed output.
        """
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        """Return the scaled action for the state batch ``x``."""
        hidden = torch.relu(self.fc1(x))
        squashed = torch.tanh(self.fc2(hidden))
        return squashed * self.action_bound
class RMMPolicyNet(torch.nn.Module):
    """Two-headed policy whose second head is conditioned on the first.

    The first head maps the state to ``action_dim`` values squashed by a
    sigmoid. The second head receives the state concatenated with that
    first output and produces ``action_dim`` values squashed by tanh. The
    forward pass returns both heads concatenated, i.e. ``2 * action_dim``
    values per sample.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        """Create both heads.

        Args:
            state_dim: Width of the state vector.
            hidden_dim: Width of the hidden layer inside each head.
            action_dim: Number of outputs produced by each head.
        """
        super().__init__()

        def build_head(in_features):
            # Linear -> ReLU -> Linear; both heads differ only in input width.
            return nn.Sequential(
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, action_dim),
            )

        self.fc1 = build_head(state_dim)
        self.fc2 = build_head(state_dim + action_dim)

    def forward(self, x):
        """Return ``[sigmoid head, tanh head]`` joined along dimension 1."""
        first = torch.sigmoid(self.fc1(x))
        conditioned_input = torch.cat((x, first), dim=1)
        second = torch.tanh(self.fc2(conditioned_input))
        return torch.cat((first, second), dim=1)
class QValueNet(torch.nn.Module):
    """Critic that scores a (state, action) pair with one value per sample."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        """Create the hidden and output layers.

        Args:
            state_dim: Width of the state vector.
            hidden_dim: Width of the hidden layer.
            action_dim: Width of the action vector.
        """
        super().__init__()
        joint_dim = state_dim + action_dim
        self.fc1 = nn.Linear(joint_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        """Return the value estimate for states ``x`` and actions ``a``.

        The two inputs are concatenated along dimension 1 before the first
        layer.
        """
        state_action = torch.cat((x, a), dim=1)
        hidden = torch.relu(self.fc1(state_action))
        return self.fc2(hidden)
class TwoLayerFC(torch.nn.Module):
    """Fully connected network with two hidden layers and a pluggable output.

    ``activation`` is applied after each hidden layer; ``out_fn`` is applied
    to the result of the final linear layer (identity by default).
    """

    def __init__(
        self, num_in, num_out, hidden_dim, activation=F.relu, out_fn=lambda x: x
    ):
        """Create the three linear layers and store the two callables.

        Args:
            num_in: Width of the input vector.
            num_out: Width of the output vector.
            hidden_dim: Width of both hidden layers.
            activation: Callable applied after ``fc1`` and after ``fc2``.
            out_fn: Callable applied to the output of ``fc3``.
        """
        super().__init__()
        layer_widths = (num_in, hidden_dim, hidden_dim, num_out)
        self.fc1 = nn.Linear(layer_widths[0], layer_widths[1])
        self.fc2 = nn.Linear(layer_widths[1], layer_widths[2])
        self.fc3 = nn.Linear(layer_widths[2], layer_widths[3])
        self.activation = activation
        self.out_fn = out_fn

    def forward(self, x):
        """Run ``x`` through both hidden layers, then the output transform."""
        hidden = x
        for hidden_layer in (self.fc1, self.fc2):
            hidden = self.activation(hidden_layer(hidden))
        return self.out_fn(self.fc3(hidden))
class DDPG:
    """Deep Deterministic Policy Gradient agent.

    Holds an actor and a critic, a target copy of each, and one Adam
    optimizer per online network. ``take_action`` produces a noisy action
    for one state; ``update`` performs one critic step, one actor step and
    a soft update of both target networks.
    """

    def __init__(
        self,
        num_in_actor,
        num_out_actor,
        num_in_critic,
        hidden_dim,
        discrete,
        action_bound,
        sigma,
        actor_lr,
        critic_lr,
        tau,
        gamma,
        device,
        use_rmm=True,
    ):
        """Build the networks, target copies and optimizers.

        Args:
            num_in_actor: Width of the state vector fed to the actor.
            num_out_actor: Action width passed to the actor constructor. Note
                that ``RMMPolicyNet`` emits ``2 * num_out_actor`` values.
            num_in_critic: Input width of the critic. ``update`` feeds it the
                state concatenated with the action, so it must equal the
                state width plus the actor's output width.
            hidden_dim: Hidden-layer width of every network.
            discrete: Only used when ``use_rmm`` is false. If truthy the
                actor output is left untouched, otherwise it is squashed
                with ``tanh`` and scaled by ``action_bound``.
            action_bound: Scale of the tanh-squashed actor output (non-RMM,
                non-discrete case only).
            sigma: Scale of the Gaussian exploration noise in ``take_action``.
            actor_lr: Adam learning rate of the actor.
            critic_lr: Adam learning rate of the critic.
            tau: Interpolation factor used by ``soft_update``.
            gamma: Discount factor of the TD target.
            device: Device on which the networks and batches live.
            use_rmm: Select ``RMMPolicyNet`` (default) or ``TwoLayerFC`` as
                the actor architecture.
        """
        out_fn = (lambda x: x) if discrete else (lambda x: torch.tanh(x) * action_bound)
        if use_rmm:
            self.actor = RMMPolicyNet(
                num_in_actor,
                hidden_dim,
                num_out_actor,
            ).to(device)
            self.target_actor = RMMPolicyNet(
                num_in_actor,
                hidden_dim,
                num_out_actor,
            ).to(device)
        else:
            self.actor = TwoLayerFC(
                num_in_actor,
                num_out_actor,
                hidden_dim,
                activation=F.relu,
                out_fn=out_fn,
            ).to(device)
            self.target_actor = TwoLayerFC(
                num_in_actor,
                num_out_actor,
                hidden_dim,
                activation=F.relu,
                out_fn=out_fn,
            ).to(device)
        self.critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
        self.target_critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
        # Target networks start as exact copies of the online networks.
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.sigma = sigma
        self.action_bound = action_bound
        self.tau = tau
        self.action_dim = num_out_actor
        self.device = device

    def _to_tensor(self, values):
        """Convert a batch field to a float32 tensor on ``self.device``."""
        if isinstance(values, torch.Tensor):
            return values.detach().to(dtype=torch.float, device=self.device)
        # Stack through NumPy first: torch.tensor() on a list of ndarrays
        # converts element by element and is very slow.
        return torch.as_tensor(
            np.asarray(values), dtype=torch.float, device=self.device
        )

    def take_action(self, state):
        """Return a noisy, clipped action for a single ``state``.

        Gaussian noise scaled by ``sigma`` is added to the actor output, then
        component 0 is clipped to ``[0, 1]`` and component 1 to ``[-1, 1]``.
        The action must therefore have at least two components.

        Args:
            state: One unbatched state (array-like).

        Returns:
            A NumPy array holding the perturbed action.
        """
        state = torch.tensor(
            np.expand_dims(state, 0), dtype=torch.float, device=self.device
        )
        # Inference only: no autograd graph is needed for acting.
        with torch.no_grad():
            action = self.actor(state)[0].cpu().numpy()
        # Draw noise with the action's own shape. The RMM actor emits
        # 2 * action_dim values, so sampling only `action_dim` values either
        # broadcast one sample onto every component or failed outright.
        action = action + self.sigma * np.random.randn(*action.shape)
        action[0] = np.clip(action[0], 0, 1)
        action[1] = np.clip(action[1], -1, 1)
        return action

    def save_state_dict(self, name):
        """Save the weights of all four networks to ``name`` via ``torch.save``."""
        dicts = {
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
        }
        torch.save(dicts, name)

    def load_state_dict(self, name):
        """Restore the four networks from a file written by ``save_state_dict``."""
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be restored on this agent's device.
        dicts = torch.load(name, map_location=self.device)
        self.critic.load_state_dict(dicts["critic"])
        self.target_critic.load_state_dict(dicts["target_critic"])
        self.actor.load_state_dict(dicts["actor"])
        self.target_actor.load_state_dict(dicts["target_actor"])

    def soft_update(self, net, target_net):
        """Move ``target_net`` towards ``net``: ``t <- (1 - tau) * t + tau * p``."""
        with torch.no_grad():
            for param_target, param in zip(target_net.parameters(), net.parameters()):
                param_target.copy_(
                    param_target * (1.0 - self.tau) + param * self.tau
                )

    def update(self, transition_dict):
        """Run one DDPG training step on a batch of transitions.

        Args:
            transition_dict: Mapping with the keys ``"states"``,
                ``"actions"``, ``"rewards"``, ``"next_states"`` and
                ``"dones"``. Rewards and dones are reshaped to column
                vectors; states and actions are concatenated along
                dimension 1 to form the critic input.
        """
        states = self._to_tensor(transition_dict["states"])
        actions = self._to_tensor(transition_dict["actions"])
        rewards = self._to_tensor(transition_dict["rewards"]).reshape(-1, 1)
        next_states = self._to_tensor(transition_dict["next_states"])
        dones = self._to_tensor(transition_dict["dones"]).reshape(-1, 1)

        # The TD target is a constant for the critic regression. Computing it
        # without autograd avoids back-propagating through the target
        # networks and accumulating never-cleared gradients on them.
        with torch.no_grad():
            next_actions = self.target_actor(next_states)
            next_q_values = self.target_critic(
                torch.cat([next_states, next_actions], dim=1)
            )
            q_targets = rewards + self.gamma * next_q_values * (1 - dones)

        # F.mse_loss already averages over the batch.
        critic_loss = F.mse_loss(
            self.critic(torch.cat([states, actions], dim=1)),
            q_targets,
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # The actor ascends the critic's value of its own actions.
        actor_loss = -torch.mean(
            self.critic(torch.cat([states, self.actor(states)], dim=1))
        )
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        logging.info(f"update DDPG: actor loss {actor_loss.item():.3f}, critic loss {critic_loss.item():.3f}, ")
        self.soft_update(self.actor, self.target_actor)  # soft-update the target policy net
        self.soft_update(self.critic, self.target_critic)  # soft-update the target Q value net