|
import math |
|
import numpy as np |
|
|
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch import distributions as torchd |
|
|
|
from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, static_scan, \ |
|
OneHotDist, ContDist, SymlogDist, DreamerLayerNorm |
|
|
|
|
|
class RSSM(nn.Module):
    """Recurrent state-space model: a recurrent deterministic state plus a stochastic latent.

    The stochastic latent is either a diagonal Gaussian (``discrete`` falsy; state
    dict keys ``mean``/``std``/``stoch``/``deter``) or ``stoch`` categorical
    variables with ``discrete`` classes each (keys ``logit``/``stoch``/``deter``).

    Args:
        stoch: Number of stochastic latent variables.
        deter: Size of the deterministic (recurrent) state.
        hidden: Width of every hidden Linear layer.
        layers_input: Number of Linear/norm/act stages in front of the recurrent cell.
        layers_output: Number of Linear/norm/act stages in the prior and posterior heads.
        rec_depth: Number of times the recurrent cell is applied per step.
        shared: If true, the posterior is produced by the prior network fed with
            the observation embedding instead of by a separate head.
        discrete: Falsy for a Gaussian latent, otherwise the number of classes
            per latent variable.
        act: Activation class; instantiated with no arguments.
        norm: Normalisation class; called as ``norm(hidden, eps=1e-03)``.
        mean_act: ``"none"`` or ``"tanh5"``.
        std_act: ``"softplus"``, ``"abs"``, ``"sigmoid"`` or ``"sigmoid2"``.
        temp_post: If true the posterior head sees ``[deter, embed]``, else only ``embed``.
        min_std: Constant added to the standard deviation after ``std_act``.
        cell: ``"gru"`` or ``"gru_layer_norm"``.
        unimix_ratio: Forwarded to ``OneHotDist`` for the discrete latent.
        num_actions: Size of the action vector. Must be an int; the default
            ``None`` fails when the input width is computed.
        embed: Size of the observation embedding. Must be an int; it is used to
            size the posterior head.
        device: Device on which ``initial`` allocates its tensors.

    Raises:
        NotImplementedError: If ``cell`` is not one of the supported names.
    """

    def __init__(
        self,
        stoch=30,
        deter=200,
        hidden=200,
        layers_input=1,
        layers_output=1,
        rec_depth=1,
        shared=False,
        discrete=False,
        act=nn.ELU,
        norm=nn.LayerNorm,
        mean_act="none",
        std_act="softplus",
        temp_post=True,
        min_std=0.1,
        cell="gru",
        unimix_ratio=0.01,
        num_actions=None,
        embed=None,
        device=None,
    ):
        super(RSSM, self).__init__()
        self._stoch = stoch
        self._deter = deter
        self._hidden = hidden
        self._min_std = min_std
        self._layers_input = layers_input
        self._layers_output = layers_output
        self._rec_depth = rec_depth
        self._shared = shared
        self._discrete = discrete
        self._act = act
        self._norm = norm
        self._mean_act = mean_act
        self._std_act = std_act
        self._temp_post = temp_post
        self._unimix_ratio = unimix_ratio
        self._embed = embed
        self._device = device

        # Input stack: [flattened stoch, action (, embed when shared)] -> hidden.
        stoch_dim = self._stoch * self._discrete if self._discrete else self._stoch
        inp_dim = stoch_dim + num_actions
        if self._shared:
            inp_dim += self._embed
        self._inp_layers = self._build_mlp(inp_dim, self._layers_input)

        # Recurrent cell: hidden -> deter.
        if cell == "gru":
            self._cell = GRUCell(self._hidden, self._deter)
        elif cell == "gru_layer_norm":
            self._cell = GRUCell(self._hidden, self._deter, norm=True)
        else:
            raise NotImplementedError(cell)
        self._cell.apply(weight_init)

        # Prior ("imagine") head: deter -> hidden.
        self._img_out_layers = self._build_mlp(self._deter, self._layers_output)

        # Posterior ("observe") head: [deter, embed] or embed -> hidden.
        obs_dim = self._deter + self._embed if self._temp_post else self._embed
        self._obs_out_layers = self._build_mlp(obs_dim, self._layers_output)

        # Sufficient-statistics layers: logits for the discrete latent,
        # concatenated (mean, std) for the Gaussian latent.
        stat_dim = self._stoch * self._discrete if self._discrete else 2 * self._stoch
        self._ims_stat_layer = nn.Linear(self._hidden, stat_dim)
        self._ims_stat_layer.apply(weight_init)
        self._obs_stat_layer = nn.Linear(self._hidden, stat_dim)
        self._obs_stat_layer.apply(weight_init)

    def _build_mlp(self, inp_dim, num_layers):
        """Return ``num_layers`` bias-free Linear/norm/act stages, initialised with ``weight_init``."""
        stages = []
        for _ in range(num_layers):
            stages.append(nn.Linear(inp_dim, self._hidden, bias=False))
            stages.append(self._norm(self._hidden, eps=1e-03))
            stages.append(self._act())
            # Every stage after the first maps hidden -> hidden.
            inp_dim = self._hidden
        mlp = nn.Sequential(*stages)
        mlp.apply(weight_init)
        return mlp

    @staticmethod
    def _swap_leading_dims(x):
        """Exchange the first two dimensions of ``x``, leaving the rest in place."""
        return x.permute([1, 0] + list(range(2, len(x.shape))))

    @staticmethod
    def _rescale_action(action):
        """Return ``action`` with every component of magnitude above 1 scaled back to magnitude 1.

        The scale factor is detached, so gradients see a plain multiplication.
        The result is a new tensor; the caller's ``action`` is left untouched.
        """
        scale = (1.0 / torch.clip(torch.abs(action), min=1.0)).detach()
        return action * scale

    def initial(self, batch_size):
        """Return an all-zero state dict for ``batch_size`` sequences on ``self._device``."""
        deter = torch.zeros(batch_size, self._deter, device=self._device)
        if self._discrete:
            latent_shape = [batch_size, self._stoch, self._discrete]
            return dict(
                logit=torch.zeros(latent_shape, device=self._device),
                stoch=torch.zeros(latent_shape, device=self._device),
                deter=deter,
            )
        latent_shape = [batch_size, self._stoch]
        return dict(
            mean=torch.zeros(latent_shape, device=self._device),
            std=torch.zeros(latent_shape, device=self._device),
            stoch=torch.zeros(latent_shape, device=self._device),
            deter=deter,
        )

    def observe(self, embed, action, state=None):
        """Run ``obs_step`` along a sequence and return ``(post, prior)`` state dicts.

        ``embed`` and ``action`` have their first two dimensions swapped before
        scanning and the results are swapped back, so inputs are presumably
        batch-major ``(batch, time, ...)`` -- TODO confirm against callers.
        When ``state`` is None, a zero state of size ``action.shape[0]`` is used.
        """
        if state is None:
            state = self.initial(action.shape[0])

        embed = self._swap_leading_dims(embed)
        action = self._swap_leading_dims(action)
        # The carry is a (post, prior) pair; only the posterior feeds the next step.
        post, prior = static_scan(
            lambda prev_state, prev_act, embed: self.obs_step(prev_state[0], prev_act, embed),
            (action, embed),
            (state, state),
        )

        post = {k: self._swap_leading_dims(v) for k, v in post.items()}
        prior = {k: self._swap_leading_dims(v) for k, v in prior.items()}
        return post, prior

    def imagine(self, action, state=None):
        """Run ``img_step`` along an action sequence and return the prior state dict.

        Uses the same leading-dimension swap as ``observe``. When ``state`` is
        None, a zero state of size ``action.shape[0]`` is used.
        """
        if state is None:
            state = self.initial(action.shape[0])
        assert isinstance(state, dict), state
        action = self._swap_leading_dims(action)
        # static_scan's result is indexed with [0] to obtain the state dict.
        prior = static_scan(self.img_step, [action], state)[0]
        return {k: self._swap_leading_dims(v) for k, v in prior.items()}

    def get_feat(self, state):
        """Concatenate the (flattened, if discrete) stochastic latent with ``deter`` on the last axis."""
        stoch = state["stoch"]
        if self._discrete:
            # Merge the trailing (stoch, discrete) axes into a single feature axis.
            stoch = stoch.reshape(list(stoch.shape[:-2]) + [self._stoch * self._discrete])
        return torch.cat([stoch, state["deter"]], -1)

    def get_dist(self, state, dtype=None):
        """Build the latent distribution described by ``state``.

        Discrete: ``Independent(OneHotDist(logit, unimix_ratio), 1)``.
        Gaussian: ``ContDist(Independent(Normal(mean, std), 1))``.
        ``dtype`` is accepted but unused.
        """
        if self._discrete:
            one_hot = OneHotDist(state["logit"], unimix_ratio=self._unimix_ratio)
            return torchd.independent.Independent(one_hot, 1)
        normal = torchd.normal.Normal(state["mean"], state["std"])
        return ContDist(torchd.independent.Independent(normal, 1))

    def obs_step(self, prev_state, prev_action, embed, sample=True):
        """Advance one step with an observation embedding; return ``(post, prior)``.

        The prior always comes from ``img_step`` without the embedding. When
        ``shared`` is set the posterior reuses ``img_step`` with the embedding;
        otherwise it comes from the dedicated observation head and shares the
        prior's ``deter``. ``sample`` selects ``.sample()`` versus ``.mode()``.
        """
        # Out-of-place on purpose: an in-place ``*=`` would overwrite the
        # caller's action tensor (e.g. a view into the training batch).
        prev_action = self._rescale_action(prev_action)
        prior = self.img_step(prev_state, prev_action, None, sample)
        if self._shared:
            return self.img_step(prev_state, prev_action, embed, sample), prior

        if self._temp_post:
            x = torch.cat([prior["deter"], embed], -1)
        else:
            x = embed
        x = self._obs_out_layers(x)
        stats = self._suff_stats_layer("obs", x)
        dist = self.get_dist(stats)
        stoch = dist.sample() if sample else dist.mode()
        post = {"stoch": stoch, "deter": prior["deter"], **stats}
        return post, prior

    def img_step(self, prev_state, prev_action, embed=None, sample=True):
        """Advance one step without an observation and return the prior state dict.

        ``embed`` is only consumed when ``shared`` is set; if it is None in that
        case, a zero embedding is substituted. ``sample`` selects ``.sample()``
        versus ``.mode()`` for the stochastic latent.
        """
        # Out-of-place on purpose; see obs_step.
        prev_action = self._rescale_action(prev_action)
        prev_stoch = prev_state["stoch"]
        if self._discrete:
            # Flatten the trailing (stoch, discrete) axes.
            prev_stoch = prev_stoch.reshape(list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete])
        if self._shared:
            if embed is None:
                # Allocate on the action's device so the concatenation below
                # does not mix devices.
                zero_shape = list(prev_action.shape[:-1]) + [self._embed]
                embed = torch.zeros(zero_shape, device=prev_action.device)
            x = torch.cat([prev_stoch, prev_action, embed], -1)
        else:
            x = torch.cat([prev_stoch, prev_action], -1)

        x = self._inp_layers(x)
        # NOTE(review): every pass restarts from prev_state["deter"], so the
        # recurrent state is not chained across passes when rec_depth > 1, and
        # `deter` is never bound when rec_depth == 0. Kept as in the original.
        for _ in range(self._rec_depth):
            deter = prev_state["deter"]
            x, deter = self._cell(x, [deter])
            deter = deter[0]  # the cell returns its state wrapped in a list

        x = self._img_out_layers(x)
        stats = self._suff_stats_layer("ims", x)
        dist = self.get_dist(stats)
        stoch = dist.sample() if sample else dist.mode()
        return {"stoch": stoch, "deter": deter, **stats}

    def _suff_stats_layer(self, name, x):
        """Project hidden features to distribution parameters.

        Args:
            name: ``"ims"`` for the prior layer, ``"obs"`` for the posterior layer.
            x: Hidden features whose last axis has size ``hidden``.

        Returns:
            ``{"logit": ...}`` for the discrete latent, else ``{"mean": ..., "std": ...}``.

        Raises:
            NotImplementedError: If ``name`` is neither ``"ims"`` nor ``"obs"``.
            KeyError: If ``mean_act`` or ``std_act`` is not a supported name.
        """
        if name == "ims":
            x = self._ims_stat_layer(x)
        elif name == "obs":
            x = self._obs_stat_layer(x)
        else:
            raise NotImplementedError(name)

        if self._discrete:
            logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete])
            return {"logit": logit}

        mean, std = torch.split(x, [self._stoch] * 2, -1)
        mean = {
            "none": lambda: mean,
            "tanh5": lambda: 5.0 * torch.tanh(mean / 5.0),
        }[self._mean_act]()
        std = {
            # softplus is exposed through torch.nn.functional, not the torch namespace.
            "softplus": lambda: F.softplus(std),
            "abs": lambda: torch.abs(std + 1),
            "sigmoid": lambda: torch.sigmoid(std),
            "sigmoid2": lambda: 2 * torch.sigmoid(std / 2),
        }[self._std_act]()
        std = std + self._min_std
        return {"mean": mean, "std": std}

    def kl_loss(self, post, prior, forward, free, lscale, rscale):
        """Compute the two-sided KL loss between posterior and prior.

        ``forward`` picks the argument order: ``(lhs, rhs)`` is ``(prior, post)``
        when true, ``(post, prior)`` otherwise. Two divergences are evaluated,
        each with gradients stopped on one side:
        ``KL(lhs || stop_grad(rhs))`` and ``KL(stop_grad(lhs) || rhs)``.
        Each mean is clipped from below at ``free``, then the two are combined
        as ``lscale * loss_lhs + rscale * loss_rhs``.

        Returns:
            ``(loss, value, loss_lhs, loss_rhs)`` where ``value`` is the
            unreduced, unclipped ``KL(lhs || stop_grad(rhs))``.
        """
        kld = torchd.kl.kl_divergence

        def stop_grad(state):
            return {k: v.detach() for k, v in state.items()}

        def as_torch_dist(state):
            # The Gaussian case is wrapped in ContDist; its `_dist` attribute is
            # presumably the underlying torch distribution -- TODO confirm.
            dist = self.get_dist(state)
            return dist if self._discrete else dist._dist

        lhs, rhs = (prior, post) if forward else (post, prior)

        value_lhs = value = kld(as_torch_dist(lhs), as_torch_dist(stop_grad(rhs)))
        value_rhs = kld(as_torch_dist(stop_grad(lhs)), as_torch_dist(rhs))
        loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
        loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
        loss = lscale * loss_lhs + rscale * loss_rhs

        return loss, value, loss_lhs, loss_rhs
|
|
|
|
|
class ConvDecoder(nn.Module):
    """Decode a feature vector into an image-shaped tensor with transposed convolutions.

    A Linear layer expands the features to ``self._embed_size`` values, which
    are viewed as a 4x4 grid and upsampled by one stride-2
    ``nn.ConvTranspose2d`` per entry of ``kernels``. The output is wrapped in
    ``SymlogDist``.

    NOTE(review): the ``64`` in the embed-size formula and the 4x4 grid are
    hard-coded; the sizes look tuned for four stride-2 kernels and a 64x64
    output -- confirm before using other settings.

    Args:
        inp_depth: Size of the last axis of the input features.
        depth: Base channel multiplier.
        act: Activation class appended after every layer except the last.
        norm: If truthy, a ``DreamerLayerNorm`` follows every layer except the
            last. The value is stored but only its truthiness is used.
        shape: Output shape appended to the leading feature dimensions;
            ``shape[0]`` is the channel count of the last layer.
        kernels: Kernel size of each transposed convolution.
        outscale: Argument passed to ``uniform_weight_init`` for the last layer.
    """

    def __init__(
        self,
        inp_depth,
        depth=32,
        act=nn.ELU,
        norm=nn.LayerNorm,
        shape=(3, 64, 64),
        kernels=(3, 3, 3, 3),
        outscale=1.0,
    ):
        super(ConvDecoder, self).__init__()
        self._inp_depth = inp_depth
        self._act = act
        self._norm = norm
        self._depth = depth
        self._shape = shape
        self._kernels = kernels
        self._embed_size = ((64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1))

        self._linear_layer = nn.Linear(inp_depth, self._embed_size)

        num_layers = len(self._kernels)
        in_ch = self._embed_size // 16
        use_norm = norm
        layers = []
        for i, kernel in enumerate(self._kernels):
            is_last = i == num_layers - 1
            if is_last:
                # Output layer: image channels, bias, no norm/activation, and
                # its own initialiser.
                out_ch = self._shape[0]
                layer_act = False
                use_norm = False
                initializer = uniform_weight_init(outscale)
            else:
                out_ch = self._embed_size // 16 // (2 ** (i + 1))
                layer_act = self._act
                initializer = weight_init

            if i != 0:
                in_ch = 2 ** (num_layers - i - 1) * self._depth
            # Height and width use the same kernel/stride, so one computation
            # serves both axes.
            pad, outpad = self.calc_same_pad(k=kernel, s=2, d=1)
            layers.append(
                nn.ConvTranspose2d(
                    in_ch,
                    out_ch,
                    kernel,
                    2,
                    padding=(pad, pad),
                    output_padding=(outpad, outpad),
                    bias=is_last,
                )
            )
            if use_norm:
                layers.append(DreamerLayerNorm(out_ch))
            if layer_act:
                layers.append(layer_act())
            # Initialise the trailing (up to three) modules. For the last layer
            # this window also covers the previous stage's norm/activation,
            # exactly as in the original code.
            for module in layers[-3:]:
                module.apply(initializer)

        self.layers = nn.Sequential(*layers)

    def calc_same_pad(self, k, s, d):
        """Return ``(padding, output_padding)`` for kernel ``k``, stride ``s`` and dilation ``d``."""
        val = d * (k - 1) - s + 1
        pad = math.ceil(val / 2)
        outpad = pad * 2 - val
        return pad, outpad

    def forward(self, features, dtype=None):
        """Decode ``features`` and return ``SymlogDist(mean)``.

        ``mean`` is reshaped to ``features.shape[:-1] + shape``. ``dtype`` is
        accepted but unused.
        """
        x = self._linear_layer(features)
        # Fold all leading dimensions into one batch axis: channels-last 4x4
        # grid, then channels-first for the convolutions.
        x = x.reshape([-1, 4, 4, self._embed_size // 16])
        x = x.permute(0, 3, 1, 2)
        x = self.layers(x)
        # `shape` may be a tuple (the default); convert before concatenating.
        mean = x.reshape(list(features.shape[:-1]) + list(self._shape))

        return SymlogDist(mean)
|
|
|
|
|
class GRUCell(nn.Module):
    """GRU-style cell whose three gates come from a single bias-free Linear layer.

    Args:
        inp_size: Size of the input's last axis.
        size: Size of the recurrent state (and of the output).
        norm: If truthy, apply ``nn.LayerNorm(3 * size, eps=1e-03)`` to the
            gate pre-activations.
        act: Activation applied to the candidate state.
        update_bias: Constant added to the update-gate pre-activation.
    """

    def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
        super(GRUCell, self).__init__()
        self._inp_size = inp_size
        self._size = size
        self._act = act
        self._update_bias = update_bias
        self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
        # Either a LayerNorm module or None, so forward() tests for presence
        # explicitly instead of relying on the truthiness of a module.
        self._norm = nn.LayerNorm(3 * size, eps=1e-03) if norm else None

    @property
    def state_size(self):
        """Size of the recurrent state."""
        return self._size

    def forward(self, inputs, state):
        """Apply one recurrent step.

        Args:
            inputs: Tensor whose last axis has size ``inp_size``.
            state: Sequence whose first element is the previous state tensor.

        Returns:
            ``(output, [output])``: the new state, and the same tensor wrapped
            in a list.
        """
        prev = state[0]
        parts = self._layer(torch.cat([inputs, prev], -1))
        if self._norm is not None:
            parts = self._norm(parts)
        reset, cand, update = torch.split(parts, [self._size] * 3, -1)
        reset = torch.sigmoid(reset)
        cand = self._act(reset * cand)
        update = torch.sigmoid(update + self._update_bias)
        # Blend the candidate with the previous state.
        output = update * cand + (1 - update) * prev
        return output, [output]
|
|