import math

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd

from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, static_scan, \
    OneHotDist, ContDist, SymlogDist, DreamerLayerNorm


class RSSM(nn.Module):

    def __init__(
        self,
        stoch=30,
        deter=200,
        hidden=200,
        layers_input=1,
        layers_output=1,
        rec_depth=1,
        shared=False,
        discrete=False,
        act=nn.ELU,
        norm=nn.LayerNorm,
        mean_act="none",
        std_act="softplus",
        temp_post=True,
        min_std=0.1,
        cell="gru",
        unimix_ratio=0.01,
        num_actions=None,
        embed=None,
        device=None,
    ):
        super(RSSM, self).__init__()
        self._stoch = stoch
        self._deter = deter
        self._hidden = hidden
        self._min_std = min_std
        self._layers_input = layers_input
        self._layers_output = layers_output
        self._rec_depth = rec_depth
        self._shared = shared
        self._discrete = discrete
        self._act = act
        self._norm = norm
        self._mean_act = mean_act
        self._std_act = std_act
        self._temp_post = temp_post
        self._unimix_ratio = unimix_ratio
        self._embed = embed
        self._device = device

        inp_layers = []
        if self._discrete:
            inp_dim = self._stoch * self._discrete + num_actions
        else:
            inp_dim = self._stoch + num_actions
        if self._shared:
            inp_dim += self._embed
        for i in range(self._layers_input):
            inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            inp_layers.append(self._norm(self._hidden, eps=1e-03))
            inp_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._inp_layers = nn.Sequential(*inp_layers)
        self._inp_layers.apply(weight_init)

        if cell == "gru":
            self._cell = GRUCell(self._hidden, self._deter)
            self._cell.apply(weight_init)
        elif cell == "gru_layer_norm":
            self._cell = GRUCell(self._hidden, self._deter, norm=True)
            self._cell.apply(weight_init)
        else:
            raise NotImplementedError(cell)

        img_out_layers = []
        inp_dim = self._deter
        for i in range(self._layers_output):
            img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            img_out_layers.append(self._norm(self._hidden, eps=1e-03))
            img_out_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._img_out_layers = nn.Sequential(*img_out_layers)
        self._img_out_layers.apply(weight_init)

        obs_out_layers = []
        if self._temp_post:
            inp_dim = self._deter + self._embed
        else:
            inp_dim = self._embed
        for i in range(self._layers_output):
            obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
            obs_out_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._obs_out_layers = nn.Sequential(*obs_out_layers)
        self._obs_out_layers.apply(weight_init)

        if self._discrete:
            self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
            self._ims_stat_layer.apply(weight_init)
            self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
            self._obs_stat_layer.apply(weight_init)
        else:
            self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
            self._ims_stat_layer.apply(weight_init)
            self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
            self._obs_stat_layer.apply(weight_init)

    def initial(self, batch_size):
        deter = torch.zeros(batch_size, self._deter).to(self._device)
        if self._discrete:
            state = dict(
                logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device),
                stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device),
                deter=deter,
            )
        else:
            state = dict(
                mean=torch.zeros([batch_size, self._stoch]).to(self._device),
                std=torch.zeros([batch_size, self._stoch]).to(self._device),
                stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
                deter=deter,
            )
        return state
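    # Note (added for readability, derived from the code below): the posterior
    # (`obs_step`, which conditions on the current embedding) and the prior
    # (`img_step`, which only sees the previous state and action) are rolled out
    # over time with `static_scan`. Both `observe` and `imagine` take inputs of
    # shape (batch, time, ...) and internally work in (time, batch, ...) layout.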
    def observe(self, embed, action, state=None):
        swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))  # swap the first two dims
        if state is None:
            state = self.initial(action.shape[0])  # {logit, stoch, deter}
        # (batch, time, ch) -> (time, batch, ch)
        embed, action = swap(embed), swap(action)
        post, prior = static_scan(
            lambda prev_state, prev_act, embed: self.obs_step(prev_state[0], prev_act, embed),
            (action, embed),
            (state, state),
        )
        # (time, batch, stoch, discrete_num) -> (batch, time, stoch, discrete_num)
        post = {k: swap(v) for k, v in post.items()}
        prior = {k: swap(v) for k, v in prior.items()}
        return post, prior

    def imagine(self, action, state=None):
        swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
        if state is None:
            state = self.initial(action.shape[0])
        assert isinstance(state, dict), state
        action = swap(action)
        prior = static_scan(self.img_step, [action], state)
        prior = prior[0]
        prior = {k: swap(v) for k, v in prior.items()}
        return prior

    def get_feat(self, state):
        stoch = state["stoch"]
        if self._discrete:
            shape = list(stoch.shape[:-2]) + [self._stoch * self._discrete]
            stoch = stoch.reshape(shape)
        return torch.cat([stoch, state["deter"]], -1)

    def get_dist(self, state, dtype=None):
        if self._discrete:
            logit = state["logit"]
            dist = torchd.independent.Independent(OneHotDist(logit, unimix_ratio=self._unimix_ratio), 1)
        else:
            mean, std = state["mean"], state["std"]
            dist = ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), 1))
        return dist

    def obs_step(self, prev_state, prev_action, embed, sample=True):
        # if shared is True, prior and post both use the same networks (_inp_layers, _img_out_layers, _ims_stat_layer);
        # otherwise, post uses a separate network (_obs_out_layers) with prior["deter"] and embed as inputs
        prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
        prior = self.img_step(prev_state, prev_action, None, sample)
        if self._shared:
            post = self.img_step(prev_state, prev_action, embed, sample)
        else:
            if self._temp_post:
                x = torch.cat([prior["deter"], embed], -1)
            else:
                x = embed
            # (batch_size, prior_deter + embed) -> (batch_size, hidden)
            x = self._obs_out_layers(x)
            # (batch_size, hidden) -> (batch_size, stoch, discrete_num)
            stats = self._suff_stats_layer("obs", x)
            if sample:
                stoch = self.get_dist(stats).sample()
            else:
                stoch = self.get_dist(stats).mode()
            post = {"stoch": stoch, "deter": prior["deter"], **stats}
        return post, prior

    # this is used for making future image
    def img_step(self, prev_state, prev_action, embed=None, sample=True):
        # (batch, stoch, discrete_num)
        prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
        prev_stoch = prev_state["stoch"]
        if self._discrete:
            shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
            # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
            prev_stoch = prev_stoch.reshape(shape)
        if self._shared:
            if embed is None:
                shape = list(prev_action.shape[:-1]) + [self._embed]
                embed = torch.zeros(shape)
            # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action + embed)
            x = torch.cat([prev_stoch, prev_action, embed], -1)
        else:
            x = torch.cat([prev_stoch, prev_action], -1)
        # (batch, stoch * discrete_num + action) -> (batch, hidden)
        x = self._inp_layers(x)
        for _ in range(self._rec_depth):  # rec depth is not correctly implemented
            deter = prev_state["deter"]
            # (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
            x, deter = self._cell(x, [deter])
            deter = deter[0]  # Keras wraps the state in a list.
        # (batch, deter) -> (batch, hidden)
        x = self._img_out_layers(x)
        # (batch, hidden) -> (batch_size, stoch, discrete_num)
        stats = self._suff_stats_layer("ims", x)
        if sample:
            stoch = self.get_dist(stats).sample()
        else:
            stoch = self.get_dist(stats).mode()
        prior = {"stoch": stoch, "deter": deter, **stats}  # {stoch, deter, logit}
        return prior
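    # Note (added for readability, derived from the code below): `_suff_stats_layer`
    # maps the hidden vector to distribution parameters. Discrete latents get logits
    # of shape (..., stoch, discrete); continuous latents get a (mean, std) pair,
    # where std passes through the configured activation (std_act) and is then
    # shifted by min_std to keep it strictly positive.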
    def _suff_stats_layer(self, name, x):
        if self._discrete:
            if name == "ims":
                x = self._ims_stat_layer(x)
            elif name == "obs":
                x = self._obs_stat_layer(x)
            else:
                raise NotImplementedError
            logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete])
            return {"logit": logit}
        else:
            if name == "ims":
                x = self._ims_stat_layer(x)
            elif name == "obs":
                x = self._obs_stat_layer(x)
            else:
                raise NotImplementedError
            mean, std = torch.split(x, [self._stoch] * 2, -1)
            mean = {
                "none": lambda: mean,
                "tanh5": lambda: 5.0 * torch.tanh(mean / 5.0),
            }[self._mean_act]()
            std = {
                "softplus": lambda: F.softplus(std),
                "abs": lambda: torch.abs(std + 1),
                "sigmoid": lambda: torch.sigmoid(std),
                "sigmoid2": lambda: 2 * torch.sigmoid(std / 2),
            }[self._std_act]()
            std = std + self._min_std
            return {"mean": mean, "std": std}

    def kl_loss(self, post, prior, forward, free, lscale, rscale):
        kld = torchd.kl.kl_divergence
        dist = lambda x: self.get_dist(x)
        sg = lambda x: {k: v.detach() for k, v in x.items()}
        # forward == False -> (lhs, rhs) = (post, prior)
        lhs, rhs = (prior, post) if forward else (post, prior)
        # forward == False -> representation loss Lrep: KL(post || sg(prior))
        value_lhs = value = kld(
            dist(lhs) if self._discrete else dist(lhs)._dist,
            dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist,
        )
        # forward == False -> dynamics loss Ldyn: KL(sg(post) || prior)
        value_rhs = kld(
            dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
            dist(rhs) if self._discrete else dist(rhs)._dist,
        )
        # free bits: clip each term at `free` nats before weighting
        loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
        loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
        loss = lscale * loss_lhs + rscale * loss_rhs
        return loss, value, loss_lhs, loss_rhs


class ConvDecoder(nn.Module):

    def __init__(
        self,
        inp_depth,  # config.dyn_stoch * config.dyn_discrete + config.dyn_deter
        depth=32,
        act=nn.ELU,
        norm=nn.LayerNorm,
        shape=(3, 64, 64),
        kernels=(3, 3, 3, 3),
        outscale=1.0,
    ):
        super(ConvDecoder, self).__init__()
        self._inp_depth = inp_depth
        self._act = act
        self._norm = norm
        self._depth = depth
        self._shape = shape
        self._kernels = kernels
        self._embed_size = ((64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1))

        self._linear_layer = nn.Linear(inp_depth, self._embed_size)
        inp_dim = self._embed_size // 16  # divide by the final 4*4 feature map to get the channel count

        layers = []
        h, w = 4, 4
        for i, kernel in enumerate(self._kernels):
            depth = self._embed_size // 16 // (2 ** (i + 1))
            act = self._act
            bias = False
            initializer = weight_init
            if i == len(self._kernels) - 1:
                depth = self._shape[0]
                act = False
                bias = True
                norm = False
                initializer = uniform_weight_init(outscale)
            if i != 0:
                inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
            pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
            pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
            layers.append(
                nn.ConvTranspose2d(
                    inp_dim,
                    depth,
                    kernel,
                    2,
                    padding=(pad_h, pad_w),
                    output_padding=(outpad_h, outpad_w),
                    bias=bias,
                )
            )
            if norm:
                layers.append(DreamerLayerNorm(depth))
            if act:
                layers.append(act())
            [m.apply(initializer) for m in layers[-3:]]
            h, w = h * 2, w * 2
        self.layers = nn.Sequential(*layers)

    def calc_same_pad(self, k, s, d):
        val = d * (k - 1) - s + 1
        pad = math.ceil(val / 2)
        outpad = pad * 2 - val
        return pad, outpad

    def __call__(self, features, dtype=None):
        # features: [batch, time, stoch * discrete + deter]
        x = self._linear_layer(features)
        x = x.reshape([-1, 4, 4, self._embed_size // 16])
        x = x.permute(0, 3, 1, 2)
        x = self.layers(x)
        mean = x.reshape(list(features.shape[:-1]) + list(self._shape))
        # mean = mean.permute(0, 1, 3, 4, 2)
        return SymlogDist(mean)
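# A worked check of `calc_same_pad` (added arithmetic, not from the source): for
# k=3, s=2, d=1 it gives val = 1*(3 - 1) - 2 + 1 = 1, pad = ceil(1/2) = 1, and
# outpad = 2*1 - 1 = 1. Plugging into the ConvTranspose2d output-size formula,
# H_out = (H_in - 1)*2 - 2*1 + 1*(3 - 1) + 1 + 1 = 2*H_in, so each deconv stage
# exactly doubles the spatial size, taking the 4x4 map to 64x64 over four stages.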
class GRUCell(nn.Module):

    def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
        super(GRUCell, self).__init__()
        self._inp_size = inp_size  # hidden
        self._size = size  # deter
        self._act = act
        self._norm = norm
        self._update_bias = update_bias
        self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
        if norm:
            self._norm = nn.LayerNorm(3 * size, eps=1e-03)

    @property
    def state_size(self):
        return self._size

    def forward(self, inputs, state):
        state = state[0]  # Keras wraps the state in a list.
        parts = self._layer(torch.cat([inputs, state], -1))
        if self._norm:
            parts = self._norm(parts)
        reset, cand, update = torch.split(parts, [self._size] * 3, -1)
        reset = torch.sigmoid(reset)
        cand = self._act(reset * cand)
        update = torch.sigmoid(update + self._update_bias)
        output = update * cand + (1 - update) * state
        return output, [output]
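# A minimal smoke-test sketch (added; not part of the original module). All shapes
# and hyperparameters below are illustrative assumptions, not values from any
# DI-engine config; it only exercises the classes defined in this file.
if __name__ == "__main__":
    batch, time, num_actions, embed_dim = 2, 5, 6, 1024
    rssm = RSSM(
        stoch=32,
        deter=200,
        hidden=200,
        discrete=32,
        cell="gru_layer_norm",
        num_actions=num_actions,
        embed=embed_dim,
        device="cpu",
    )
    embed = torch.randn(batch, time, embed_dim)
    action = F.one_hot(torch.randint(num_actions, (batch, time)), num_actions).float()
    # posterior/prior rollout over the dummy trajectory
    post, prior = rssm.observe(embed, action)
    assert post["stoch"].shape == (batch, time, 32, 32)
    feat = rssm.get_feat(post)  # (batch, time, 32 * 32 + 200)
    loss, value, loss_lhs, loss_rhs = rssm.kl_loss(post, prior, forward=False, free=1.0, lscale=0.1, rscale=0.5)
    # decode features back to a distribution over 64x64 RGB frames
    dec = ConvDecoder(inp_depth=feat.shape[-1])
    dist = dec(feat)  # SymlogDist over (batch, time, 3, 64, 64) means
    print(feat.shape, loss.item())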