import torch.nn as nn
import torch
import tools.utils as utils
import agent.dreamer_utils as common
from collections import OrderedDict
import numpy as np
# NOTE: kept last so later definitions in this module win over any star-imported
# names. Reward / report functions are looked up by name through ``globals()``
# (``cfg.acting_reward_fn``, ``cfg.additional_report_fns``); presumably some of
# them come from this module -- confirm.
from tools.genrl_utils import *


def stop_gradient(x):
    """Return ``x`` detached from the autograd graph (TF-style alias)."""
    return x.detach()


Module = nn.Module


def env_reward(agent, seq):
    """Reward function: mean of the world model's reward head on ``seq['feat']``.

    Selectable by name through ``cfg.acting_reward_fn``.
    """
    return agent.wm.heads['reward'](seq['feat']).mean


class DreamerAgent(Module):
    """Dreamer agent: a world model plus an actor-critic trained in imagination.

    Args:
        name: identifier stored on the agent.
        cfg: configuration object; it is updated in place with ``kwargs``.
        obs_space: mapping from observation key to an object exposing ``.shape``.
        act_spec: action spec exposing ``.shape`` and ``.dtype``.
        **kwargs: extra configuration entries merged into ``cfg``.
    """

    def __init__(self, name, cfg, obs_space, act_spec, **kwargs):
        super().__init__()
        self.name = name
        self.cfg = cfg
        self.cfg.update(**kwargs)
        self.obs_space = obs_space
        self.act_spec = act_spec
        self._use_amp = (cfg.precision == 16)
        self.device = cfg.device
        self.act_dim = act_spec.shape[0]
        self.wm = WorldModel(cfg, obs_space, self.act_dim,)
        self.instantiate_acting_behavior()
        self.to(cfg.device)
        # All parameters start frozen. The update methods wrap training in
        # ``common.RequiresGrad(...)``, presumably to re-enable gradients only
        # for the duration of an update.
        self.requires_grad_(requires_grad=False)

    def instantiate_acting_behavior(self,):
        """Create the actor-critic that acts on world-model features."""
        self._acting_behavior = ActorCritic(self.cfg, self.act_spec, self.wm.inp_size).to(self.device)

    def act(self, obs, meta, step, eval_mode, state):
        """Select an action for a single, unbatched observation.

        Args:
            obs: dict of array-likes for one time step; a batch dim of 1 is added.
            meta: unused here.
            step: unused here.
            eval_mode: if True, take the actor distribution's mean instead of
                sampling.
            state: ``None`` on the first step, otherwise the ``(latent, action)``
                tuple returned by the previous call.

        Returns:
            ``(action, new_state)``. ``action`` is a numpy array with the batch
            dimension removed.
        """
        if self.cfg.only_random_actions:
            random_action = np.random.uniform(-1, 1, self.act_dim,).astype(self.act_spec.dtype)
            return random_action, (None, None)

        obs = {k: torch.as_tensor(np.copy(v), device=self.device).unsqueeze(0)
               for k, v in obs.items()}
        if state is None:
            batch_size = len(obs['reward'])
            latent = self.wm.rssm.initial(batch_size)
            action = torch.zeros((batch_size,) + self.act_spec.shape, device=self.device)
        else:
            latent, action = state

        embed = self.wm.encoder(self.wm.preprocess(obs))
        should_sample = (not eval_mode) or (not self.cfg.eval_state_mean)
        latent, _ = self.wm.rssm.obs_step(latent, action, embed, obs['is_first'], should_sample)
        feat = self.wm.rssm.get_feat(latent)

        actor = self._acting_behavior.actor(feat)
        if eval_mode:
            try:
                action = actor.mean
            except (AttributeError, NotImplementedError):
                # Some distribution wrappers do not expose a usable ``mean``;
                # fall back to their private ``_mean`` attribute.
                action = actor._mean
        else:
            action = actor.sample()
        new_state = (latent, action)
        return action.cpu().numpy()[0], new_state

    def update_wm(self, data, step):
        """Run one world-model update.

        Returns:
            ``(state, outputs, metrics)``. ``outputs`` additionally carries
            ``data['is_terminal']`` under ``'is_terminal'``.
        """
        metrics = {}
        state, outputs, mets = self.wm.update(data, state=None)
        outputs['is_terminal'] = data['is_terminal']
        metrics.update(mets)
        return state, outputs, metrics

    def update_acting_behavior(self, state=None, outputs=None, metrics=None, data=None, reward_fn=None):
        """Run one actor-critic update from posterior start states.

        Args:
            state: unused here; accepted for call compatibility.
            outputs: world-model outputs holding ``'post'`` and ``'is_terminal'``.
                If ``None``, the posterior is recomputed from ``data``.
            metrics: dict to extend with new metrics. A fresh dict is used when
                ``None`` (previously a shared mutable default).
            data: raw batch, used only when ``outputs`` is ``None``.
            reward_fn: optional ``reward_fn(agent, seq)``. Defaults to the
                module-level function named by ``cfg.acting_reward_fn``.

        Returns:
            ``(start, metrics)``, where ``start`` is the detached posterior.
            If ``cfg.only_random_actions`` is set, returns ``({}, metrics)``.
        """
        if metrics is None:
            metrics = {}
        if self.cfg.only_random_actions:
            return {}, metrics

        if outputs is not None:
            post = outputs['post']
            is_terminal = outputs['is_terminal']
        else:
            data = self.wm.preprocess(data)
            embed = self.wm.encoder(data)
            post, _ = self.wm.rssm.observe(embed, data['action'], data['is_first'])
            is_terminal = data['is_terminal']

        # Imagination starts from the posterior, without backprop into the world model.
        start = {k: stop_gradient(v) for k, v in post.items()}
        if reward_fn is None:
            acting_reward_fn = lambda seq: globals()[self.cfg.acting_reward_fn](self, seq)
        else:
            acting_reward_fn = lambda seq: reward_fn(self, seq)
        metrics.update(self._acting_behavior.update(self.wm, start, is_terminal, acting_reward_fn))
        return start, metrics

    def update(self, data, step):
        """Update the world model, then the actor-critic. Returns ``(state, metrics)``."""
        state, outputs, metrics = self.update_wm(data, step)
        start, metrics = self.update_acting_behavior(state, outputs, metrics, data)
        return state, metrics

    def report(self, data):
        """Build open-loop video predictions for each decoder CNN key.

        Also merges the output of every function named in
        ``cfg.additional_report_fns`` (looked up in ``globals()``).
        """
        report = {}
        data = self.wm.preprocess(data)
        for key in self.wm.heads['decoder'].cnn_keys:
            name = key.replace('/', '_')
            report[f'openl_{name}'] = self.wm.video_pred(data, key)
        for fn in getattr(self.cfg, 'additional_report_fns', []):
            call_fn = globals()[fn]
            additional_report = call_fn(self, data)
            report.update(additional_report)
        return report

    def get_meta_specs(self):
        """No meta specs are used by this agent."""
        return tuple()

    def init_meta(self):
        """Return an empty meta dict."""
        return OrderedDict()

    def update_meta(self, meta, global_step, time_step, finetune=False):
        """Return ``meta`` unchanged."""
        return meta


class WorldModel(Module):
    """Encoder + ensemble RSSM + prediction heads (decoder, reward, discount).

    Args:
        config: configuration object (encoder/rssm/heads/optimizer settings).
        obs_space: mapping from observation key to an object exposing ``.shape``.
        act_dim: action dimensionality passed to the RSSM.
    """

    def __init__(self, config, obs_space, act_dim,):
        super().__init__()
        shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
        self.shapes = shapes
        self.cfg = config
        self.device = config.device
        self.encoder = common.Encoder(shapes, **config.encoder)
        # Infer the embedding size with a dummy forward pass on a batch of zeros.
        with torch.no_grad():
            zeros = {k: torch.zeros((1,) + v) for k, v in shapes.items()}
            outs = self.encoder(zeros)
            embed_dim = outs.shape[1]
        self.embed_dim = embed_dim
        self.rssm = common.EnsembleRSSM(**config.rssm, action_dim=act_dim, embed_dim=embed_dim, device=self.device,)
        self.heads = {}
        self._use_amp = (config.precision == 16)
        self.inp_size = self.rssm.get_feat_size()
        # The decoder input (e.g. full feature vs. part of the state) is selected
        # by name via ``config.decoder_inputs``.
        self.decoder_input_fn = getattr(self.rssm, f'get_{config.decoder_inputs}')
        self.decoder_input_size = getattr(self.rssm, f'get_{config.decoder_inputs}_size')()
        self.heads['decoder'] = common.Decoder(shapes, **config.decoder, embed_dim=self.decoder_input_size, image_dist=config.image_dist)
        self.heads['reward'] = common.MLP(self.inp_size, (1,), **config.reward_head)
        # Zero-init the reward head's output layer.
        with torch.no_grad():
            for p in self.heads['reward']._out.parameters():
                p.data = p.data * 0
        # The discount head is always built (no ``config.pred_discount`` gate).
        self.heads['discount'] = common.MLP(self.inp_size, (1,), **config.discount_head)
        for name in config.grad_heads:
            assert name in self.heads, name
        self.grad_heads = config.grad_heads
        self.heads = nn.ModuleDict(self.heads)
        self.model_opt = common.Optimizer('model', self.parameters(), **config.model_opt, use_amp=self._use_amp)
        self.e2e_update_fns = {}
        self.detached_update_fns = {}
        self.eval()

    def add_module_to_update(self, name, module, update_fn, detached=False):
        """Register an extra submodule with its own loss function.

        ``update_fn(wm, name, data, outputs, metrics)`` must return
        ``(loss, metrics)``. When ``detached`` is False the loss is added to the
        model loss. When True the module gets its own optimizer step. The model
        optimizer is rebuilt so that it includes the new parameters.
        """
        self.add_module(name, module)
        if detached:
            self.detached_update_fns[name] = update_fn
        else:
            self.e2e_update_fns[name] = update_fn
        self.model_opt = common.Optimizer('model', self.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)

    def update(self, data, state=None):
        """Run one optimisation step of the world model.

        Honours the ``freeze_decoder`` / ``freeze_post`` / ``freeze_model``
        config flags. Returns ``(last_state, outputs, metrics)``.
        """
        self.train()
        with common.RequiresGrad(self):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                if getattr(self.cfg, "freeze_decoder", False):
                    self.heads['decoder'].requires_grad_(False)
                if getattr(self.cfg, "freeze_post", False) or getattr(self.cfg, "freeze_model", False):
                    self.heads['decoder'].requires_grad_(False)
                    self.encoder.requires_grad_(False)
                    # Updating only the prior. Note: this permanently clears grad_heads.
                    self.grad_heads = []
                    self.rssm.requires_grad_(False)
                    if not getattr(self.cfg, "freeze_model", False):
                        self.rssm._ensemble_img_out.requires_grad_(True)
                        self.rssm._ensemble_img_dist.requires_grad_(True)
                model_loss, state, outputs, metrics = self.loss(data, state)
                model_loss, metrics = self.update_additional_e2e_modules(data, outputs, model_loss, metrics)
            metrics.update(self.model_opt(model_loss, self.parameters()))
        if len(self.detached_update_fns) > 0:
            detached_loss, metrics = self.update_additional_detached_modules(data, outputs, metrics)
        self.eval()
        return state, outputs, metrics

    def update_additional_detached_modules(self, data, outputs, metrics):
        """Optimise each detached module on its own loss.

        Returns:
            ``(detached_loss, metrics)``. ``detached_loss`` is the sum of the
            per-module losses, detached from the graph (0 when there are none).
        """
        detached_loss = 0
        for k in self.detached_update_fns:
            detached_module = getattr(self, k)
            with common.RequiresGrad(detached_module):
                with torch.cuda.amp.autocast(enabled=self._use_amp):
                    add_loss, add_metrics = self.detached_update_fns[k](self, k, data, outputs, metrics)
                    metrics.update(add_metrics)
                opt_metrics = self.model_opt(add_loss, detached_module.parameters())
                metrics.update({f'{k}_{m}': opt_metrics[m] for m in opt_metrics})
            # Previously never accumulated; keep it graph-free for reporting.
            detached_loss = detached_loss + add_loss.detach()
        return detached_loss, metrics

    def update_additional_e2e_modules(self, data, outputs, model_loss, metrics):
        """Add every registered end-to-end module loss to ``model_loss``."""
        for k in self.e2e_update_fns:
            add_loss, add_metrics = self.e2e_update_fns[k](self, k, data, outputs, metrics)
            model_loss += add_loss
            metrics.update(add_metrics)
        return model_loss, metrics

    def observe_data(self, data, state=None):
        """Encode ``data`` and run the RSSM posterior/prior without computing losses.

        Returns ``(outs, {'model_kl': mean KL})``.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        post, prior = self.rssm.observe(embed, data['action'], data['is_first'], state)
        kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
        outs = dict(embed=embed, post=post, prior=prior, is_terminal=data['is_terminal'])
        return outs, {'model_kl': kl_value.mean()}

    def loss(self, data, state=None):
        """Compute the world-model loss (KL + negative log-likelihood of each head).

        Heads not listed in ``self.grad_heads`` receive detached inputs.
        Each loss is scaled by ``cfg.loss_scales[key]`` (default 1.0).

        Returns:
            ``(model_loss, last_state, outs, metrics)``. ``last_state`` is the
            posterior at the final time index.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        post, prior = self.rssm.observe(embed, data['action'], data['is_first'], state)
        kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
        assert len(kl_loss.shape) == 0 or (len(kl_loss.shape) == 1 and kl_loss.shape[0] == 1), kl_loss.shape
        likes = {}
        losses = {'kl': kl_loss}
        feat = self.rssm.get_feat(post)
        for name, head in self.heads.items():
            grad_head = (name in self.grad_heads)
            if name == 'decoder':
                inp = self.decoder_input_fn(post)
            else:
                inp = feat
            inp = inp if grad_head else stop_gradient(inp)
            out = head(inp)
            # The decoder returns one distribution per key; other heads return one.
            dists = out if isinstance(out, dict) else {name: out}
            for key, dist in dists.items():
                like = dist.log_prob(data[key])
                likes[key] = like
                losses[key] = -like.mean()
        model_loss = sum(
            self.cfg.loss_scales.get(k, 1.0) * v for k, v in losses.items())
        outs = dict(
            embed=embed, feat=feat, post=post,
            prior=prior, likes=likes, kl=kl_value)
        metrics = {f'{name}_loss': value for name, value in losses.items()}
        metrics['model_kl'] = kl_value.mean()
        metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean()
        metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean()
        last_state = {k: v[:, -1] for k, v in post.items()}
        return model_loss, last_state, outs, metrics

    def imagine(self, policy, start, is_terminal, horizon, task_cond=None, eval_policy=False):
        """Roll out ``policy`` in latent space for ``horizon`` steps.

        The first two dims of every ``start`` entry are flattened into one batch
        dim. ``task_cond``, if given, is concatenated to the policy input.

        Returns:
            dict of tensors stacked along a new leading time dim of length
            ``horizon + 1``. Keys are the start keys plus ``'action'``,
            ``'feat'``, ``'discount'``, ``'weight'`` and optionally ``'task'``.
        """
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}
        start['feat'] = self.rssm.get_feat(start)
        inp = start['feat'] if task_cond is None else torch.cat([start['feat'], task_cond], dim=-1)
        policy_dist = policy(inp)
        # The first action is a zero placeholder with the policy's action shape.
        start['action'] = torch.zeros_like(policy_dist.sample(), device=self.device)
        seq = {k: [v] for k, v in start.items()}
        if task_cond is not None:
            seq['task'] = [task_cond]
        for _ in range(horizon):
            inp = seq['feat'][-1] if task_cond is None else torch.cat([seq['feat'][-1], task_cond], dim=-1)
            policy_dist = policy(stop_gradient(inp))
            action = policy_dist.sample() if not eval_policy else policy_dist.mean
            state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
            feat = self.rssm.get_feat(state)
            for key, value in {**state, 'action': action, 'feat': feat}.items():
                seq[key].append(value)
            if task_cond is not None:
                seq['task'].append(task_cond)
        # shape will be (T, B, *DIMS)
        seq = {k: torch.stack(v, 0) for k, v in seq.items()}
        if 'discount' in self.heads:
            disc = self.heads['discount'](seq['feat']).mean()
            if is_terminal is not None:
                # Override the discount prediction for the first step with the
                # true discount factor from the replay buffer.
                true_first = 1.0 - flatten(is_terminal)
                disc = torch.cat([true_first[None], disc[1:]], 0)
        else:
            disc = torch.ones(list(seq['feat'].shape[:-1]) + [1], device=self.device)
        seq['discount'] = disc * self.cfg.discount
        # Shift discount factors: they say whether the *following* state will be
        # valid, not whether the current one is.
        seq['weight'] = torch.cumprod(torch.cat([torch.ones_like(disc[:1], device=self.device), disc[:-1]], 0), 0)
        return seq

    def preprocess(self, obs):
        """Return a shallow copy of ``obs`` prepared for the model.

        uint8 entries are scaled to [-0.5, 0.5], except keys starting with
        ``'log_'``. ``'reward'`` is transformed according to
        ``cfg.clip_rewards`` ('identity' | 'sign' | 'tanh'). ``'discount'`` is
        set to ``1 - is_terminal`` and expanded to the reward's rank if needed.
        """
        obs = obs.copy()
        for key, value in obs.items():
            if key.startswith('log_'):
                continue
            if value.dtype in [np.uint8, torch.uint8]:
                value = value / 255.0 - 0.5
            obs[key] = value
        obs['reward'] = {
            'identity': nn.Identity(),
            'sign': torch.sign,
            'tanh': torch.tanh,
        }[self.cfg.clip_rewards](obs['reward'])
        obs['discount'] = (1.0 - obs['is_terminal'].float())
        if len(obs['discount'].shape) < len(obs['reward'].shape):
            obs['discount'] = obs['discount'].unsqueeze(-1)
        return obs

    def video_pred(self, data, key, nvid=8):
        """Build a truth / prediction / error video for ``key`` on ``nvid`` sequences.

        The first 5 steps are reconstructed from posterior states. The rest are
        predicted open-loop from the recorded actions. The three videos are
        concatenated along dim 3 (H, if the layout is (B, T, C, H, W) as the
        comment below states -- confirm).
        """
        decoder = self.heads['decoder']
        # B, T, C, H, W
        truth = data[key][:nvid] + 0.5
        embed = self.encoder(data)
        states, _ = self.rssm.observe(
            embed[:nvid, :5], data['action'][:nvid, :5], data['is_first'][:nvid, :5])
        recon = decoder(self.decoder_input_fn(states))[key].mean[:nvid]  # mean used as the mode
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.rssm.imagine(data['action'][:nvid, 5:], init)
        prior_recon = decoder(self.decoder_input_fn(prior))[key].mean  # mean used as the mode
        model = torch.clip(torch.cat([recon[:, :5] + 0.5, prior_recon + 0.5], 1), 0, 1)
        error = (model - truth + 1) / 2
        video = torch.cat([truth, model, error], 3)
        B, T, C, H, W = video.shape
        return video


class ActorCritic(Module):
    """Actor and critic (optionally with a slow target critic) trained on imagined rollouts.

    Args:
        config: configuration object.
        act_spec: action spec exposing ``.shape``.
        feat_size: size of the world-model feature vector fed to both networks.
        name: optional prefix for metric keys. It also selects the config entry
            ``<name>_actor_grad`` (just ``actor_grad`` when empty).
    """

    def __init__(self, config, act_spec, feat_size, name=''):
        super().__init__()
        self.name = name
        self.cfg = config
        self.act_spec = act_spec
        self._use_amp = (config.precision == 16)
        self.device = config.device
        if getattr(self.cfg, 'discrete_actions', False):
            self.cfg.actor.dist = 'onehot'
        self.actor_grad = getattr(self.cfg, f'{self.name}_actor_grad'.strip('_'))

        inp_size = feat_size
        self.actor = common.MLP(inp_size, act_spec.shape[0], **self.cfg.actor)
        self.critic = common.MLP(inp_size, (1,), **self.cfg.critic)
        if self.cfg.slow_target:
            self._target_critic = common.MLP(inp_size, (1,), **self.cfg.critic)
            self._updates = 0
        else:
            self._target_critic = self.critic
        self.actor_opt = common.Optimizer('actor', self.actor.parameters(), **self.cfg.actor_opt, use_amp=self._use_amp)
        self.critic_opt = common.Optimizer('critic', self.critic.parameters(), **self.cfg.critic_opt, use_amp=self._use_amp)

        if self.cfg.reward_ema:
            # Registered as a buffer so it is included in torch.save / torch.load.
            self.register_buffer("ema_vals", torch.zeros((2,)).to(self.device))
            self.reward_ema = common.RewardEMA(device=self.device)
            self.rewnorm = common.StreamNorm(momentum=1, scale=1.0, device=self.device)
        else:
            self.rewnorm = common.StreamNorm(**self.cfg.reward_norm, device=self.device)

        # Zero-init the critic's output layer.
        with torch.no_grad():
            for p in self.critic._out.parameters():
                p.data = p.data * 0
        # Hard-copy the initial critic params into the target critic.
        for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
            d.data = s.data

    def update(self, world_model, start, is_terminal, reward_fn):
        """One actor step and one critic step on a rollout imagined from ``start``.

        ``reward_fn(seq)`` must return per-step rewards for the imagined sequence.
        Returns the metrics dict, with keys prefixed by ``name`` when it is set.
        """
        metrics = {}
        hor = self.cfg.imag_horizon
        # The weights are is_terminal flags for the imagination start states.
        # Technically, they should multiply the losses from the second trajectory
        # step onwards, which is the first imagined step. However, we are not
        # training the action that led into the first step anyway, so we can use
        # them to scale the whole sequence.
        with common.RequiresGrad(self.actor):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                seq = world_model.imagine(self.actor, start, is_terminal, hor)
                reward = reward_fn(seq)
                seq['reward'], mets1 = self.rewnorm(reward)
                mets1 = {f'reward_{k}': v for k, v in mets1.items()}
                target, mets2, baseline = self.target(seq)
                actor_loss, mets3 = self.actor_loss(seq, target, baseline)
            metrics.update(self.actor_opt(actor_loss, self.actor.parameters()))
        with common.RequiresGrad(self.critic):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                seq = {k: stop_gradient(v) for k, v in seq.items()}
                critic_loss, mets4 = self.critic_loss(seq, target)
            metrics.update(self.critic_opt(critic_loss, self.critic.parameters()))
        metrics.update(**mets1, **mets2, **mets3, **mets4)
        self.update_slow_target()  # Variables exist after first forward pass.
        return {f'{self.name}_{k}'.strip('_'): v for k, v in metrics.items()}

    def actor_loss(self, seq, target, baseline):
        """Actor loss: weighted negative (return objective + entropy bonus).

        Two state-actions are lost at the end of the trajectory: one for the
        bootstrap value prediction, and one because the corresponding action no
        longer leads anywhere. One target is lost at the start of the trajectory
        because the initial state comes from the replay buffer.
        """
        policy = self.actor(stop_gradient(seq['feat'][:-2]))  # actions are the ones in [1:-1]
        metrics = {}
        if self.cfg.reward_ema:
            offset, scale = self.reward_ema(target, self.ema_vals)
            normed_target = (target - offset) / scale
            normed_baseline = (baseline - offset) / scale
            metrics['normed_target_mean'] = normed_target.mean()
            metrics['normed_target_std'] = normed_target.std()
            metrics["reward_ema_005"] = self.ema_vals[0]
            metrics["reward_ema_095"] = self.ema_vals[1]
        else:
            normed_target = target
            normed_baseline = baseline

        if self.actor_grad == 'dynamics':
            objective = normed_target[1:]
        elif self.actor_grad == 'reinforce':
            advantage = normed_target[1:] - normed_baseline[1:]
            objective = policy.log_prob(stop_gradient(seq['action'][1:-1]))[:, :, None] * advantage
        else:
            raise NotImplementedError(self.actor_grad)

        ent = policy.entropy()[:, :, None]
        ent_scale = self.cfg.actor_ent
        # Out-of-place on purpose. In the 'dynamics' branch without reward EMA,
        # ``objective`` is a view of ``target``; an in-place ``+=`` would write
        # the entropy bonus into the target that the critic later regresses on.
        objective = objective + ent_scale * ent
        metrics['actor_ent'] = ent.mean()
        metrics['actor_ent_scale'] = ent_scale
        weight = stop_gradient(seq['weight'])
        actor_loss = -(weight[:-2] * objective).mean()
        return actor_loss, metrics

    def critic_loss(self, seq, target):
        """Weighted negative log-likelihood of the (detached) lambda-return target."""
        feat = seq['feat'][:-1]
        target = stop_gradient(target)
        weight = stop_gradient(seq['weight'])
        dist = self.critic(feat)
        critic_loss = -(dist.log_prob(target)[:, :, None] * weight[:-1]).mean()
        metrics = {'critic': dist.mean.mean()}
        return critic_loss, metrics

    def target(self, seq):
        """Compute lambda-returns from the target critic.

        Returns ``(target, metrics, baseline)``. The baseline is the target
        critic's value with the last step dropped.
        """
        reward = seq['reward']
        disc = seq['discount']
        value = self._target_critic(seq['feat']).mean
        # Skipping last time step because it is used for bootstrapping.
        target = common.lambda_return(
            reward[:-1], value[:-1], disc[:-1],
            bootstrap=value[-1],
            lambda_=self.cfg.discount_lambda,
            axis=0)
        metrics = {}
        metrics['critic_slow'] = value.mean()
        metrics['critic_target'] = target.mean()
        return target, metrics, value[:-1]

    def update_slow_target(self):
        """Polyak-mix critic params into the target critic every ``slow_target_update`` calls.

        The first mix (``_updates == 0``) is a full copy.
        """
        if self.cfg.slow_target:
            if self._updates % self.cfg.slow_target_update == 0:
                mix = 1.0 if self._updates == 0 else float(
                    self.cfg.slow_target_fraction)
                for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1