"""Soft actor-critic (SAC) actor/critic classes, plus mixture-entropy-regularised variants."""
import torch
import torch.nn.functional as F
from stable_baselines3.common.utils import polyak_update
from torch.optim import Adam

from src.drl.nets import *
class SoftActor:
    """SAC actor: a stochastic policy network plus a learnable temperature (alpha).

    Bundles the policy net, its optimiser, and the entropy-temperature
    coefficient (with its own optimiser), exposing backward/zero/step helpers
    so a training loop can interleave policy and temperature updates.
    """

    def __init__(self, net_constructor, tar_ent=None, lr=3e-4):
        """
        Args:
            net_constructor: zero-arg callable returning the policy network.
                The net must expose ``act_dim`` and be callable as
                ``net(obs, deterministic)`` returning ``(actions, log_probs)``
                — assumed from usage below; confirm against src.drl.nets.
            tar_ent: target entropy; defaults to ``-act_dim`` (the standard
                SAC heuristic).
            lr: learning rate for both optimisers (previously hard-coded
                to 3e-4; default preserves old behaviour).
        """
        self.net = net_constructor()
        self.optimiser = Adam(self.net.parameters(), lr)
        self.alpha_coe = LearnableLogCoeffient()
        self.alpha_optimiser = Adam(self.alpha_coe.parameters(), lr)
        self.tar_ent = -self.net.act_dim if tar_ent is None else tar_ent
        self.device = 'cpu'

    def to(self, device):
        """Move the policy net and the alpha coefficient to *device*."""
        self.net.to(device)
        self.alpha_coe.to(device)
        self.device = device

    def eval(self):
        self.net.eval()

    def train(self):
        # BUG FIX: was ``self.net.train_NCESAC()`` — a method that only exists
        # on one specific net class, inconsistent with the generic ``eval()``
        # above. The generic actor should use nn.Module.train().
        self.net.train()

    def forward(self, obs, grad=True, deterministic=False):
        """Sample ``(actions, log_probs)`` from the policy, optionally without grad."""
        if grad:
            return self.net(obs, deterministic)
        with torch.no_grad():
            return self.net(obs, deterministic)

    def backward_policy(self, critic, obs, coe=1.):
        """Accumulate policy-loss gradients: coe * E[alpha*logpi - Q].

        The alpha coefficient is applied with ``grad=False`` so the
        temperature is not updated through the policy loss.
        """
        acts, logps = self.forward(obs)
        qvalues = critic.forward(obs, acts)
        a_loss = coe * (self.alpha_coe(logps, grad=False) - qvalues).mean()
        a_loss.backward()

    def backward_alpha(self, obs):
        """Accumulate temperature-loss gradients: -E[alpha * (logpi + target_entropy)]."""
        _, logps = self.forward(obs, grad=False)
        loss_alpha = -(self.alpha_coe(logps + self.tar_ent)).mean()
        loss_alpha.backward()

    def zero_grads(self):
        self.optimiser.zero_grad()
        self.alpha_optimiser.zero_grad()

    def grad_step(self):
        self.optimiser.step()
        self.alpha_optimiser.step()

    def get_nn_arch_str(self):
        """Printable description of the policy architecture."""
        return str(self.net) + '\n'
class SoftDoubleClipCriticQ:
    """Clipped double Q-critic for SAC with Polyak-averaged target networks.

    Maintains two independent Q nets (plus frozen-by-copy target twins) and
    always evaluates the element-wise minimum, the standard clipped-double-Q
    trick against overestimation.
    """

    def __init__(self, nn_constructor, gamma=0.99, tau=0.005, lr=3e-4):
        """
        Args:
            nn_constructor: zero-arg callable returning a Q network callable
                as ``net(obs, acts)``.
            gamma: discount factor for the Bellman target.
            tau: Polyak averaging rate for the target networks.
            lr: learning rate for both Q optimisers (previously hard-coded
                to 3e-4; default preserves old behaviour).
        """
        self.net1 = nn_constructor()
        self.net2 = nn_constructor()
        self.tar_net1 = nn_constructor()
        self.tar_net2 = nn_constructor()
        # Target nets start as exact copies of the online nets.
        self.tar_net1.load_state_dict(self.net1.state_dict())
        self.tar_net2.load_state_dict(self.net2.state_dict())
        self.opt1 = Adam(self.net1.parameters(), lr)
        self.opt2 = Adam(self.net2.parameters(), lr)
        self.device = 'cpu'
        self.gamma = gamma
        self.tau = tau

    def to(self, device):
        """Move all four Q networks to *device*."""
        self.net1.to(device)
        self.net2.to(device)
        self.tar_net1.to(device)
        self.tar_net2.to(device)
        self.device = device

    def forward(self, obs, acts, grad=True, tar=False):
        """Return min(Q1, Q2) for (obs, acts); ``tar`` selects the target twins."""
        def _min_q():
            if tar:
                q1 = self.tar_net1(obs, acts)
                q2 = self.tar_net2(obs, acts)
            else:
                q1 = self.net1(obs, acts)
                q2 = self.net2(obs, acts)
            return torch.minimum(q1, q2)
        if grad:
            return _min_q()
        with torch.no_grad():
            return _min_q()

    def compute_target(self, actor, rews, ops):
        """Soft Bellman target: r + gamma * (minQ'(s', a') - alpha*logpi(a'|s')).

        NOTE(review): no terminal mask is applied here — targets bootstrap
        through episode ends; confirm the caller handles dones upstream.
        """
        aps, logpi_aps = actor.forward(ops, grad=False)
        qps = self.forward(ops, aps, tar=True, grad=False)
        y = rews + self.gamma * (qps - actor.alpha_coe(logpi_aps, False))
        return y

    def backward_mse(self, actor, obs, acts, rews, ops):
        """Accumulate MSE Bellman-residual gradients for both Q nets."""
        y = self.compute_target(actor, rews, ops)
        loss1 = F.mse_loss(self.net1(obs, acts), y)
        loss2 = F.mse_loss(self.net2(obs, acts), y)
        loss1.backward()
        loss2.backward()

    def update_tarnet(self):
        """Polyak-average the online weights into the target networks."""
        polyak_update(self.net1.parameters(), self.tar_net1.parameters(), self.tau)
        polyak_update(self.net2.parameters(), self.tar_net2.parameters(), self.tau)

    def zero_grads(self):
        self.opt1.zero_grad()
        self.opt2.zero_grad()

    def grad_step(self):
        self.opt1.step()
        self.opt2.step()

    def get_nn_arch_str(self):
        """Printable description of both Q architectures."""
        return str(self.net1) + '\n' + str(self.net2) + '\n'
class MERegMixSoftActor(SoftActor):
    """Mixture-policy SAC actor with a mixture-entropy (ME) regulariser.

    The policy net produces a mixture/ensemble of components; ``mono=True``
    yields a single sampled action while ``mono=False`` yields per-component
    samples together with mixture weights (betas) — assumed from usage
    below; confirm against the net implementation in src.drl.nets.
    """

    def __init__(self, net_constructor, me_reg, tar_ent=None):
        """
        Args:
            net_constructor: zero-arg callable returning the mixture policy net.
            me_reg: mixture-entropy regulariser with a
                ``forward(mus, stds, betas)`` method.
            tar_ent: target entropy; see SoftActor.
        """
        super().__init__(net_constructor, tar_ent)
        self.me_reg = me_reg

    def forward(self, obs, grad=True, mono=True):
        """Query the mixture policy; ``mono`` selects single- vs per-component output."""
        if grad:
            return self.net(obs, mono)
        with torch.no_grad():
            return self.net(obs, mono)

    def backward_me_reg(self, critic_W, obs):
        """Accumulate gradients maximising the ME regulariser plus the W-critic bonus."""
        muss, stdss, betas = self.net.get_intermediate(obs)
        # Negated: optimisers minimise, we want to maximise the regulariser.
        loss1 = -torch.mean(self.me_reg.forward(muss, stdss, betas))
        actss, _, _ = esmb_sample(muss, stdss, betas, mono=False)
        wvaluess = critic_W.forward(obs, actss)
        loss2 = -(betas * wvaluess).mean()
        (loss1 + loss2).backward()

    def backward_policy(self, critic, obs, coe=1.):
        """Per-component policy loss, weighted by the mixture weights.

        ``coe`` added for signature consistency with SoftActor.backward_policy;
        its default of 1. preserves the previous behaviour.
        """
        actss, logpss, betas = self.forward(obs, mono=False)
        qvaluess = critic.forward(obs, actss)
        a_loss = coe * (betas * (self.alpha_coe(logpss, grad=False) - qvaluess)).mean()
        a_loss.backward()

    # NOTE: the former backward_alpha override duplicated
    # SoftActor.backward_alpha verbatim (it already dispatches through
    # self.forward), so it was removed as redundant.
class MERegSoftDoubleClipCriticQ(SoftDoubleClipCriticQ):
    """Double-clipped Q-critic evaluated on a per-component action ensemble."""

    def forward(self, obs, actss, grad=True, tar=False):
        """Return min(Q1, Q2) for every component action.

        Args:
            obs: (batch, obs_dim) observations.
            actss: (batch, n_components, act_dim) per-component actions —
                assumed from the expand below; confirm against callers.
        """
        def _min_q():
            # Broadcast each observation across the component axis so the Q
            # nets see matched (batch, n_components, ·) inputs.
            obss = torch.unsqueeze(obs, dim=1).expand(-1, actss.shape[1], -1)
            if tar:
                q1 = self.tar_net1(obss, actss)
                q2 = self.tar_net2(obss, actss)
            else:
                q1 = self.net1(obss, actss)
                q2 = self.net2(obss, actss)
            return torch.minimum(q1, q2)
        if grad:
            return _min_q()
        with torch.no_grad():
            return _min_q()

    def compute_target(self, actor, rews, ops):
        """Soft Bellman target using the mixture-weighted expectation over components."""
        apss, logpss, betaps = actor.forward(ops, grad=False, mono=False)
        qpss = self.forward(ops, apss, tar=True, grad=False)
        # BUG FIX: was ``qpss.squeeze()`` (no dim), which also drops the batch
        # or component axis when either equals 1, silently corrupting the
        # broadcast. Squeeze only the trailing value axis (Q nets assumed to
        # return (..., 1) — confirm).
        qps = (betaps * (qpss.squeeze(-1) - actor.alpha_coe(logpss, grad=False))).sum(dim=-1)
        y = rews + self.gamma * qps
        return y
class MERegDoubleClipCriticW(MERegSoftDoubleClipCriticQ):
    """W-critic: bootstraps the mixture-entropy regulariser signal, not env reward."""

    def compute_target(self, actor, rews, ops):
        """Bootstrapped target for the W value.

        NOTE(review): ``rews`` is accepted (matching the sibling critics'
        signature) but deliberately unused — the target is built purely from
        the ME-regulariser value plus the bootstrapped W estimate; confirm
        this is intended.
        """
        with torch.no_grad():
            # Mixture parameters and regulariser value at the next state.
            mupss, stdpss, betaps = actor.net.get_intermediate(ops)
            me_regps = actor.me_reg.forward(mupss, stdpss, betaps)
            apss, *_ = esmb_sample(mupss, stdpss, betaps, mono=False)
            wpss = self.forward(ops, apss, tar=True, grad=False)
            # Mixture-weighted bootstrap over components.
            y = me_regps + (betaps * wpss).sum(dim=-1)
        # NOTE(review): gamma scales the regulariser term as well as the
        # bootstrap — unusual relative to a standard Bellman target; verify
        # against the derivation this implements.
        return self.gamma * y
    pass
if __name__ == '__main__':
    # Scratch check: detach() severs the autograd link, so the reduced mean
    # prints without a grad_fn while the source tensor keeps requires_grad.
    grid = torch.tensor([[[1., 1.], [2., 2.], [3., 3.]]], requires_grad=True)
    detached_mean = grid.detach().mean(-1)
    print(grid)
    print(detached_mean)