# NCERL-Diverse-PCG: src/drl/ac_models.py
import torch
import torch.nn.functional as F
from torch.optim import Adam

from stable_baselines3.common.utils import polyak_update

from src.drl.nets import *


class SoftActor:
    """SAC actor: stochastic policy with a learnable entropy coefficient."""

    def __init__(self, net_constructor, tar_ent=None):
        self.net = net_constructor()
        self.optimiser = Adam(self.net.parameters(), 3e-4)
        # Learnable log(alpha); optimised separately from the policy parameters.
        self.alpha_coe = LearnableLogCoeffient()
        self.alpha_optimiser = Adam(self.alpha_coe.parameters(), 3e-4)
        # Default target entropy -|A| is the standard SAC heuristic.
        self.tar_ent = -self.net.act_dim if tar_ent is None else tar_ent
        self.device = 'cpu'
def to(self, device):
self.net.to(device)
self.alpha_coe.to(device)
self.device = device
def eval(self):
self.net.eval()
    def train(self):
        self.net.train()
    def forward(self, obs, grad=True, deterministic=False):
        # Returns (action, log-prob) from the policy net, optionally without grad.
        if grad:
            return self.net(obs, deterministic)
        with torch.no_grad():
            return self.net(obs, deterministic)
    def backward_policy(self, critic, obs, coe=1.):
        # Soft policy loss: E[alpha * log pi(a|s) - Q(s, a)], scaled by coe.
        acts, logps = self.forward(obs)
        qvalues = critic.forward(obs, acts)
        a_loss = coe * (self.alpha_coe(logps, grad=False) - qvalues).mean()
        a_loss.backward()
    def backward_alpha(self, obs):
        # Entropy-coefficient loss: drives the policy entropy towards tar_ent.
        _, logps = self.forward(obs, grad=False)
        loss_alpha = -(self.alpha_coe(logps + self.tar_ent)).mean()
        loss_alpha.backward()
def zero_grads(self):
self.optimiser.zero_grad()
self.alpha_optimiser.zero_grad()
def grad_step(self):
self.optimiser.step()
self.alpha_optimiser.step()
def get_nn_arch_str(self):
return str(self.net) + '\n'
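

# A minimal usage sketch for one SAC-style actor update (illustrative only:
# `actor`, `critic`, and the replay batch `obs` are assumed to exist, with
# `critic` a SoftDoubleClipCriticQ as defined below):
#
#   actor.zero_grads()
#   actor.backward_policy(critic, obs)  # accumulate soft policy gradient
#   actor.backward_alpha(obs)           # accumulate entropy-coefficient gradient
#   actor.grad_step()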


class SoftDoubleClipCriticQ:
    """Clipped double Q-critic with Polyak-averaged target networks, as in SAC."""

    def __init__(self, nn_constructor, gamma=0.99, tau=0.005):
        self.net1 = nn_constructor()
        self.net2 = nn_constructor()
        # Target networks start as exact copies of the online networks.
        self.tar_net1 = nn_constructor()
        self.tar_net2 = nn_constructor()
        self.tar_net1.load_state_dict(self.net1.state_dict())
        self.tar_net2.load_state_dict(self.net2.state_dict())
        self.opt1 = Adam(self.net1.parameters(), 3e-4)
        self.opt2 = Adam(self.net2.parameters(), 3e-4)
        self.device = 'cpu'
        self.gamma = gamma
        self.tau = tau
def to(self, device):
self.net1.to(device)
self.net2.to(device)
self.tar_net1.to(device)
self.tar_net2.to(device)
self.device = device
    def forward(self, obs, acts, grad=True, tar=False):
        def min_q():
            # Clipped double-Q: element-wise minimum of the two critic heads.
            if tar:
                q1 = self.tar_net1(obs, acts)
                q2 = self.tar_net2(obs, acts)
            else:
                q1 = self.net1(obs, acts)
                q2 = self.net2(obs, acts)
            return torch.minimum(q1, q2)
        if grad:
            return min_q()
        with torch.no_grad():
            return min_q()
    def compute_target(self, actor, rews, ops):
        # Soft Bellman target: y = r + gamma * (min Q'(s', a') - alpha * log pi(a'|s')).
        aps, logpi_aps = actor.forward(ops, grad=False)
        qps = self.forward(ops, aps, tar=True, grad=False)
        y = rews + self.gamma * (qps - actor.alpha_coe(logpi_aps, grad=False))
        return y
    def backward_mse(self, actor, obs, acts, rews, ops):
        # Both critics regress onto the same soft Bellman target y.
        y = self.compute_target(actor, rews, ops)
        loss1 = F.mse_loss(self.net1(obs, acts), y)
        loss2 = F.mse_loss(self.net2(obs, acts), y)
        loss1.backward()
        loss2.backward()
    def update_tarnet(self):
        # Soft update: theta_tar <- tau * theta + (1 - tau) * theta_tar.
        polyak_update(self.net1.parameters(), self.tar_net1.parameters(), self.tau)
        polyak_update(self.net2.parameters(), self.tar_net2.parameters(), self.tau)
def zero_grads(self):
self.opt1.zero_grad()
self.opt2.zero_grad()
def grad_step(self):
self.opt1.step()
self.opt2.step()
def get_nn_arch_str(self):
return str(self.net1) + '\n' + str(self.net2) + '\n'
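

# Matching critic update sketch (illustrative; obs, acts, rews, ops are a
# replay-buffer batch not constructed in this file):
#
#   critic.zero_grads()
#   critic.backward_mse(actor, obs, acts, rews, ops)  # TD regression for both heads
#   critic.grad_step()
#   critic.update_tarnet()  # Polyak-average the target networks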


class MERegMixSoftActor(SoftActor):
    """NCERL mixture actor: a Gaussian-mixture policy trained with an additional
    mixture-entropy (ME) regulariser that encourages diverse components."""

    def __init__(self, net_constructor, me_reg, tar_ent=None):
        super(MERegMixSoftActor, self).__init__(net_constructor, tar_ent)
        self.me_reg = me_reg
    def forward(self, obs, grad=True, mono=True):
        # mono=True returns a single sampled action and its log-prob; mono=False
        # returns per-component actions, log-probs, and mixture weights (betas).
        if grad:
            return self.net(obs, mono)
        with torch.no_grad():
            return self.net(obs, mono)
    def backward_me_reg(self, critic_W, obs):
        # Diversity objective: maximise the ME regulariser together with the
        # beta-weighted W-values of freshly sampled component actions.
        muss, stdss, betas = self.net.get_intermediate(obs)
        loss1 = -torch.mean(self.me_reg.forward(muss, stdss, betas))
        actss, _, _ = esmb_sample(muss, stdss, betas, mono=False)
        wvaluess = critic_W.forward(obs, actss)
        loss2 = -(betas * wvaluess).mean()
        (loss1 + loss2).backward()
    def backward_policy(self, critic, obs):
        # Beta-weighted soft policy loss over all mixture components.
        actss, logpss, betas = self.forward(obs, mono=False)
        qvaluess = critic.forward(obs, actss)
        a_loss = (betas * (self.alpha_coe(logpss, grad=False) - qvaluess)).mean()
        a_loss.backward()
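

# NCERL actor update sketch (illustrative): the mixture actor takes gradients
# from the Q-critic and the diversity (W) critic in a single step; backward_alpha
# is inherited unchanged from SoftActor.
#
#   actor.zero_grads()
#   actor.backward_policy(critic_q, obs)   # beta-weighted soft policy loss
#   actor.backward_me_reg(critic_w, obs)   # mixture-entropy regularisation
#   actor.backward_alpha(obs)              # entropy-coefficient loss
#   actor.grad_step()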


class MERegSoftDoubleClipCriticQ(SoftDoubleClipCriticQ):
    """Double-clipped Q-critic evaluated per mixture component: each observation
    is broadcast across the component axis of the action batch."""

    def forward(self, obs, actss, grad=True, tar=False):
        def min_q():
            # Expand obs to (batch, n_components, obs_dim) to pair with actss.
            obss = torch.unsqueeze(obs, dim=1).expand(-1, actss.shape[1], -1)
            if tar:
                q1 = self.tar_net1(obss, actss)
                q2 = self.tar_net2(obss, actss)
            else:
                q1 = self.net1(obss, actss)
                q2 = self.net2(obss, actss)
            return torch.minimum(q1, q2)
        if grad:
            return min_q()
        with torch.no_grad():
            return min_q()
    def compute_target(self, actor, rews, ops):
        # Beta-weighted soft Bellman target over the K mixture components:
        # y = r + gamma * sum_k beta'_k * (min Q'(s', a'_k) - alpha * log pi'_k).
        apss, logpss, betaps = actor.forward(ops, grad=False, mono=False)
        qpss = self.forward(ops, apss, tar=True, grad=False)
        qps = (betaps * (qpss.squeeze() - actor.alpha_coe(logpss, grad=False))).sum(dim=-1)
        y = rews + self.gamma * qps
        return y
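

# Shape note (assuming batch size B and K mixture components): actss is
# (B, K, act_dim), obs is broadcast to (B, K, obs_dim), the clipped Q-values
# come back one per component, and compute_target collapses the component
# axis with the mixture weights betas.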


class MERegDoubleClipCriticW(MERegSoftDoubleClipCriticQ):
    """Auxiliary W-critic that bootstraps the discounted mixture-entropy signal
    instead of the environment reward (`rews` is unused in its target)."""

    def compute_target(self, actor, rews, ops):
        with torch.no_grad():
            mupss, stdpss, betaps = actor.net.get_intermediate(ops)
            me_regps = actor.me_reg.forward(mupss, stdpss, betaps)
            apss, *_ = esmb_sample(mupss, stdpss, betaps, mono=False)
            wpss = self.forward(ops, apss, tar=True, grad=False)
            y = me_regps + (betaps * wpss).sum(dim=-1)
        return self.gamma * y
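

# The W-critic would train like the Q-critic (backward_mse and the Polyak
# target update are inherited); only compute_target differs. A sketch,
# assuming the same replay batch format as above:
#
#   critic_w.zero_grads()
#   critic_w.backward_mse(actor, obs, acts, rews, ops)  # rews unused by this target
#   critic_w.grad_step()
#   critic_w.update_tarnet()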


if __name__ == '__main__':
    # Scratch check: detaching a tensor before reducing it drops grad tracking,
    # so b carries no grad_fn even though a requires grad.
    a = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]]], requires_grad=True)
    b = a.detach().mean(-1)
    print(a)
    print(b)