import logging

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class PolicyNet(torch.nn.Module):
    """Deterministic policy: tanh output scaled to [-action_bound, action_bound].

    Kept for reference; the DDPG agent below builds RMMPolicyNet or TwoLayerFC instead.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return torch.tanh(self.fc2(x)) * self.action_bound


class RMMPolicyNet(torch.nn.Module):
    """Two-headed policy: a sigmoid head a1 in [0, 1] and a tanh head a2 in
    [-1, 1] that is conditioned on both the state and a1. The forward pass
    returns their concatenation, so the action has 2 * action_dim components."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(RMMPolicyNet, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        a1 = torch.sigmoid(self.fc1(x))
        x = torch.cat([x, a1], dim=1)
        a2 = torch.tanh(self.fc2(x))
        return torch.cat([a1, a2], dim=1)


class QValueNet(torch.nn.Module):
    """State-action value network Q(s, a). Kept for reference; the DDPG agent
    below uses TwoLayerFC for its critic."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)  # concatenate state and action
        x = F.relu(self.fc1(cat))
        return self.fc2(x)


class TwoLayerFC(torch.nn.Module):
    """MLP with two hidden layers, a configurable activation, and an optional
    output transform (e.g. tanh scaling for bounded actions)."""

    def __init__(
        self, num_in, num_out, hidden_dim, activation=F.relu, out_fn=lambda x: x
    ):
        super().__init__()
        self.fc1 = nn.Linear(num_in, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_out)
        self.activation = activation
        self.out_fn = out_fn

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.out_fn(self.fc3(x))
        return x


class DDPG:
    """Deep Deterministic Policy Gradient agent with soft-updated target networks."""

    def __init__(
        self,
        num_in_actor,
        num_out_actor,
        num_in_critic,
        hidden_dim,
        discrete,
        action_bound,
        sigma,
        actor_lr,
        critic_lr,
        tau,
        gamma,
        device,
        use_rmm=True,
    ):
        # For continuous actions the plain actor squashes its output with tanh
        # and scales it to the action bound; RMMPolicyNet bounds its own heads.
        out_fn = (lambda x: x) if discrete else (lambda x: torch.tanh(x) * action_bound)
        if use_rmm:
            self.actor = RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(device)
            self.target_actor = RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(device)
        else:
            self.actor = TwoLayerFC(
                num_in_actor, num_out_actor, hidden_dim, activation=F.relu, out_fn=out_fn
            ).to(device)
            self.target_actor = TwoLayerFC(
                num_in_actor, num_out_actor, hidden_dim, activation=F.relu, out_fn=out_fn
            ).to(device)
        self.critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
        self.target_critic = TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)
        # Start the target networks as exact copies of the online networks.
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma  # discount factor
        self.sigma = sigma  # std of the Gaussian exploration noise
        self.action_bound = action_bound
        self.tau = tau  # soft-update rate for the target networks
        self.action_dim = num_out_actor
        self.device = device

    def take_action(self, state):
        state = torch.tensor(np.expand_dims(state, 0), dtype=torch.float).to(self.device)
        action = self.actor(state)[0].detach().cpu().numpy()
        # Add Gaussian exploration noise, then clip each head back to its range:
        # the sigmoid head to [0, 1] and the tanh head to [-1, 1].
        action = action + self.sigma * np.random.randn(self.action_dim)
        action[0] = np.clip(action[0], 0, 1)
        action[1] = np.clip(action[1], -1, 1)
        return action

    def save_state_dict(self, name):
        dicts = {
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
        }
        torch.save(dicts, name)

    def load_state_dict(self, name):
        # map_location keeps loading robust when checkpoints were saved on another device.
        dicts = torch.load(name, map_location=self.device)
        self.critic.load_state_dict(dicts["critic"])
        self.target_critic.load_state_dict(dicts["target_critic"])
        self.actor.load_state_dict(dicts["actor"])
        self.target_actor.load_state_dict(dicts["target_actor"])

    def soft_update(self, net, target_net):
        # Polyak averaging: target <- (1 - tau) * target + tau * online.
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(
                param_target.data * (1.0 - self.tau) + param.data * self.tau
            )

    def update(self, transition_dict):
        states = torch.tensor(transition_dict["states"], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict["actions"], dtype=torch.float).to(self.device)
        rewards = (
            torch.tensor(transition_dict["rewards"], dtype=torch.float)
            .view(-1, 1)
            .to(self.device)
        )
        next_states = torch.tensor(transition_dict["next_states"], dtype=torch.float).to(self.device)
        dones = (
            torch.tensor(transition_dict["dones"], dtype=torch.float)
            .view(-1, 1)
            .to(self.device)
        )

        # Critic update: regress Q(s, a) towards the TD target built from the
        # target actor and target critic (detached so no gradients flow into them).
        next_q_values = self.target_critic(
            torch.cat([next_states, self.target_actor(next_states)], dim=1)
        )
        q_targets = rewards + self.gamma * next_q_values * (1 - dones)
        critic_loss = F.mse_loss(
            self.critic(torch.cat([states, actions], dim=1)), q_targets.detach()
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update: maximize Q(s, pi(s)), i.e. minimize its negative mean.
        actor_loss = -torch.mean(self.critic(torch.cat([states, self.actor(states)], dim=1)))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        logging.info(
            f"update DDPG: actor loss {actor_loss.item():.3f}, critic loss {critic_loss.item():.3f}"
        )
        self.soft_update(self.actor, self.target_actor)  # soft-update the target policy net
        self.soft_update(self.critic, self.target_critic)  # soft-update the target Q value net
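

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test that
# wires the DDPG agent to a batch of random transitions and one action query.
# The state/action dimensions, hyperparameters, and the assumption that each
# RMM head emits a single value (so the full action is 2-dimensional) are
# illustrative only and must be adapted to the real environment.
if __name__ == "__main__":
    state_dim = 4   # assumed observation size
    action_dim = 1  # assumed per-head output size -> RMM action has 2 components
    hidden_dim = 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    agent = DDPG(
        num_in_actor=state_dim,
        num_out_actor=action_dim,
        num_in_critic=state_dim + 2 * action_dim,  # critic sees state + full RMM action
        hidden_dim=hidden_dim,
        discrete=False,
        action_bound=1.0,
        sigma=0.1,
        actor_lr=3e-4,
        critic_lr=3e-3,
        tau=0.005,
        gamma=0.98,
        device=device,
        use_rmm=True,
    )

    # One fake batch of transitions, just to exercise update() end to end.
    batch = 32
    transition_dict = {
        "states": np.random.randn(batch, state_dim).astype(np.float32),
        "actions": np.random.randn(batch, 2 * action_dim).astype(np.float32),
        "rewards": np.random.randn(batch).astype(np.float32),
        "next_states": np.random.randn(batch, state_dim).astype(np.float32),
        "dones": np.zeros(batch, dtype=np.float32),
    }
    agent.update(transition_dict)
    print(agent.take_action(np.random.randn(state_dim).astype(np.float32)))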