Spaces:
Running
on
Zero
Running
on
Zero
import torch.nn as nn | |
import torch | |
import tools.utils as utils | |
import agent.dreamer_utils as common | |
from collections import OrderedDict | |
import numpy as np | |
from tools.genrl_utils import * | |
def stop_gradient(x):
    """Return ``x`` cut off from the autograd graph.

    Thin wrapper around ``x.detach()``; the result is whatever ``detach``
    returns for the given object.
    """
    detached = x.detach()
    return detached
# Short alias: the classes in this file subclass `Module` instead of spelling out `nn.Module`.
Module = nn.Module
def env_reward(agent, seq):
    """Reward function backed by the world model's reward head.

    Feeds ``seq['feat']`` through ``agent.wm.heads['reward']`` and returns the
    ``mean`` attribute of the object the head produces.
    """
    reward_head = agent.wm.heads['reward']
    reward_dist = reward_head(seq['feat'])
    return reward_dist.mean
class DreamerAgent(Module):
    """Agent combining a ``WorldModel`` with an ``ActorCritic`` acting behavior.

    The world model is stored in ``self.wm`` and the behavior in
    ``self._acting_behavior``. ``act`` picks actions from observations,
    ``update`` trains the world model and then the behavior, and ``report``
    collects open-loop video predictions.
    """

    def __init__(self, name, cfg, obs_space, act_spec, **kwargs):
        """Build the world model and the acting behavior.

        Args:
            name: Identifier stored on the agent.
            cfg: Configuration object; ``kwargs`` are merged into it through
                ``cfg.update(**kwargs)``.
            obs_space: Mapping of observation keys to objects with ``.shape``.
            act_spec: Action spec; ``shape[0]`` is used as the action size and
                ``dtype`` for random actions.
            **kwargs: Extra configuration overrides.
        """
        super().__init__()
        self.name = name
        self.cfg = cfg
        self.cfg.update(**kwargs)
        self.obs_space = obs_space
        self.act_spec = act_spec
        self._use_amp = (cfg.precision == 16)
        self.device = cfg.device
        self.act_dim = act_spec.shape[0]
        self.wm = WorldModel(cfg, obs_space, self.act_dim)
        self.instantiate_acting_behavior()
        self.to(cfg.device)
        # All parameters start frozen; the update paths in this file wrap their
        # training code in `common.RequiresGrad(...)` contexts
        # (presumably re-enabling gradients temporarily -- confirm in dreamer_utils).
        self.requires_grad_(requires_grad=False)

    def instantiate_acting_behavior(self):
        """Create the actor-critic used for acting and move it to ``self.device``."""
        self._acting_behavior = ActorCritic(self.cfg, self.act_spec, self.wm.inp_size).to(self.device)

    def act(self, obs, meta, step, eval_mode, state):
        """Select an action for a single (unbatched) observation.

        Args:
            obs: Mapping of observation keys to array-likes; a batch axis of
                size one is prepended to each entry.
            meta: Unused here.
            step: Unused here.
            eval_mode: When true, the actor's mean is used instead of a sample.
            state: ``None`` on the first call, otherwise the ``(latent, action)``
                tuple returned by the previous call.

        Returns:
            ``(action, new_state)`` where ``action`` is a NumPy array for the
            single batch element. With ``cfg.only_random_actions`` the action is
            uniform in ``[-1, 1]`` and the state is ``(None, None)``.
        """
        if self.cfg.only_random_actions:
            random_action = np.random.uniform(-1, 1, self.act_dim).astype(self.act_spec.dtype)
            return random_action, (None, None)

        obs = {k: torch.as_tensor(np.copy(v), device=self.device).unsqueeze(0) for k, v in obs.items()}
        if state is None:
            batch = len(obs['reward'])
            latent = self.wm.rssm.initial(batch)
            action = torch.zeros((batch,) + self.act_spec.shape, device=self.device)
        else:
            latent, action = state

        embed = self.wm.encoder(self.wm.preprocess(obs))
        should_sample = (not eval_mode) or (not self.cfg.eval_state_mean)
        latent, _ = self.wm.rssm.obs_step(latent, action, embed, obs['is_first'], should_sample)
        feat = self.wm.rssm.get_feat(latent)

        actor = self._acting_behavior.actor(feat)
        if eval_mode:
            try:
                action = actor.mean
            except (AttributeError, NotImplementedError):
                # Some distribution wrappers do not expose a usable `mean`;
                # fall back to their private `_mean` as the original code did,
                # but without swallowing unrelated exceptions.
                action = actor._mean
        else:
            action = actor.sample()

        new_state = (latent, action)
        return action.cpu().numpy()[0], new_state

    def update_wm(self, data, step):
        """Run one world-model update.

        Returns:
            ``(state, outputs, metrics)`` from ``self.wm.update``, with
            ``data['is_terminal']`` added to ``outputs`` under ``'is_terminal'``.
        """
        metrics = {}
        state, outputs, mets = self.wm.update(data, state=None)
        outputs['is_terminal'] = data['is_terminal']
        metrics.update(mets)
        return state, outputs, metrics

    def update_acting_behavior(self, state=None, outputs=None, metrics=None, data=None, reward_fn=None):
        """Train the acting behavior in imagination.

        Args:
            state: Unused here.
            outputs: World-model outputs holding ``'post'`` and
                ``'is_terminal'``. When ``None``, the posterior is recomputed
                from ``data``.
            metrics: Optional dict that is updated in place and returned. A
                fresh dict is created when omitted, so results never leak
                between calls through a shared default.
            data: Batch used only when ``outputs`` is ``None``.
            reward_fn: Optional ``fn(agent, seq)``. When ``None``, the function
                named by ``cfg.acting_reward_fn`` is looked up in this
                module's globals.

        Returns:
            ``(start, metrics)`` where ``start`` is the detached posterior used
            to seed imagination, or ``({}, metrics)`` when
            ``cfg.only_random_actions`` is set.
        """
        if metrics is None:
            metrics = {}
        if self.cfg.only_random_actions:
            return {}, metrics

        if outputs is not None:
            post = outputs['post']
            is_terminal = outputs['is_terminal']
        else:
            data = self.wm.preprocess(data)
            embed = self.wm.encoder(data)
            post, _ = self.wm.rssm.observe(embed, data['action'], data['is_first'])
            is_terminal = data['is_terminal']

        start = {k: stop_gradient(v) for k, v in post.items()}

        if reward_fn is None:
            def acting_reward_fn(seq):
                return globals()[self.cfg.acting_reward_fn](self, seq)
        else:
            def acting_reward_fn(seq):
                return reward_fn(self, seq)

        metrics.update(self._acting_behavior.update(self.wm, start, is_terminal, acting_reward_fn))
        return start, metrics

    def update(self, data, step):
        """Update the world model, then the acting behavior.

        Returns:
            ``(state, metrics)``: the world-model state and the merged metrics.
        """
        state, outputs, metrics = self.update_wm(data, step)
        _, metrics = self.update_acting_behavior(state, outputs, metrics, data)
        return state, metrics

    def report(self, data):
        """Build a report dict for logging.

        Adds one ``openl_<key>`` video prediction per key in the decoder's
        ``cnn_keys`` (with ``/`` replaced by ``_`` in the name), then merges the
        results of every function named in ``cfg.additional_report_fns``
        (looked up in this module's globals and called as ``fn(self, data)``).
        """
        report = {}
        data = self.wm.preprocess(data)
        for key in self.wm.heads['decoder'].cnn_keys:
            name = key.replace('/', '_')
            report[f'openl_{name}'] = self.wm.video_pred(data, key)
        for fn_name in getattr(self.cfg, 'additional_report_fns', []):
            report_fn = globals()[fn_name]
            report.update(report_fn(self, data))
        return report

    def get_meta_specs(self):
        """Return the meta specs; this agent has none, so an empty tuple."""
        return tuple()

    def init_meta(self):
        """Return the initial meta dict; empty for this agent."""
        return OrderedDict()

    def update_meta(self, meta, global_step, time_step, finetune=False):
        """Return ``meta`` unchanged."""
        return meta
class WorldModel(Module):
    """Latent world model: encoder, ensemble RSSM and prediction heads.

    Heads are stored in ``self.heads`` (an ``nn.ModuleDict``): always
    ``'decoder'`` and ``'reward'``, plus ``'discount'`` when
    ``config.pred_discount`` is set. Extra modules with their own losses can be
    attached through ``add_module_to_update``.
    """

    def __init__(self, config, obs_space, act_dim):
        """Construct all sub-modules and the model optimizer.

        Args:
            config: Configuration object (device, precision, encoder/rssm/
                decoder/head/optimizer settings, ``grad_heads``, ...).
            obs_space: Mapping of observation keys to objects with ``.shape``.
            act_dim: Size of the action vector passed to the RSSM.
        """
        super().__init__()
        shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
        self.shapes = shapes
        self.cfg = config
        self.device = config.device
        self.encoder = common.Encoder(shapes, **config.encoder)

        # Infer the embedding size by running the encoder once on zero inputs.
        with torch.no_grad():
            zeros = {k: torch.zeros((1,) + v) for k, v in shapes.items()}
            outs = self.encoder(zeros)
            embed_dim = outs.shape[1]
        self.embed_dim = embed_dim

        self.rssm = common.EnsembleRSSM(**config.rssm, action_dim=act_dim, embed_dim=embed_dim, device=self.device)
        self._use_amp = (config.precision == 16)
        self.inp_size = self.rssm.get_feat_size()
        # The decoder can read a different RSSM view than the other heads;
        # both the accessor and its size are resolved by name from the config.
        self.decoder_input_fn = getattr(self.rssm, f'get_{config.decoder_inputs}')
        self.decoder_input_size = getattr(self.rssm, f'get_{config.decoder_inputs}_size')()

        heads = {}
        heads['decoder'] = common.Decoder(shapes, **config.decoder, embed_dim=self.decoder_input_size, image_dist=config.image_dist)
        heads['reward'] = common.MLP(self.inp_size, (1,), **config.reward_head)
        # Zero-initialise the reward head's output layer.
        with torch.no_grad():
            for p in heads['reward']._out.parameters():
                p.data = p.data * 0
        if config.pred_discount:
            heads['discount'] = common.MLP(self.inp_size, (1,), **config.discount_head)
        for name in config.grad_heads:
            assert name in heads, name
        self.grad_heads = config.grad_heads
        self.heads = nn.ModuleDict(heads)

        self.model_opt = common.Optimizer('model', self.parameters(), **config.model_opt, use_amp=self._use_amp)
        self.e2e_update_fns = {}
        self.detached_update_fns = {}
        self.eval()

    def add_module_to_update(self, name, module, update_fn, detached=False):
        """Register an extra module whose loss is computed during ``update``.

        Args:
            name: Attribute name under which ``module`` is registered.
            module: The module to add.
            update_fn: Callable invoked as ``update_fn(self, name, data,
                outputs, metrics)`` returning ``(loss, metrics)``.
            detached: When true the module gets its own optimizer call after
                the main model update; otherwise its loss is added to the
                model loss.

        The model optimizer is rebuilt so it covers the new parameters.
        """
        self.add_module(name, module)
        if detached:
            self.detached_update_fns[name] = update_fn
        else:
            self.e2e_update_fns[name] = update_fn
        self.model_opt = common.Optimizer('model', self.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)

    def update(self, data, state=None):
        """Run one optimisation step of the world model.

        Honours the optional config flags ``freeze_decoder``, ``freeze_post``
        and ``freeze_model`` by switching off gradients on the corresponding
        sub-modules before the loss is computed.

        Returns:
            ``(state, outputs, metrics)`` as produced by ``loss`` plus optimizer
            and additional-module metrics.
        """
        self.train()
        with common.RequiresGrad(self):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                freeze_model = getattr(self.cfg, "freeze_model", False)
                if getattr(self.cfg, "freeze_decoder", False):
                    self.heads['decoder'].requires_grad_(False)
                if getattr(self.cfg, "freeze_post", False) or freeze_model:
                    self.heads['decoder'].requires_grad_(False)
                    self.encoder.requires_grad_(False)
                    # Updating only prior. Note that `grad_heads` is replaced on
                    # the instance, so the change persists after this call.
                    self.grad_heads = []
                    self.rssm.requires_grad_(False)
                    if not freeze_model:
                        self.rssm._ensemble_img_out.requires_grad_(True)
                        self.rssm._ensemble_img_dist.requires_grad_(True)
                model_loss, state, outputs, metrics = self.loss(data, state)
                model_loss, metrics = self.update_additional_e2e_modules(data, outputs, model_loss, metrics)
            metrics.update(self.model_opt(model_loss, self.parameters()))
        if len(self.detached_update_fns) > 0:
            _, metrics = self.update_additional_detached_modules(data, outputs, metrics)
        self.eval()
        return state, outputs, metrics

    def update_additional_detached_modules(self, data, outputs, metrics):
        """Update every module registered with ``detached=True``.

        Each module's loss is computed under its own ``RequiresGrad`` context
        and stepped through ``self.model_opt`` restricted to that module's
        parameters. Optimizer metrics are stored with a ``<name>_`` prefix.

        Returns:
            ``(detached_loss, metrics)``. ``detached_loss`` is always ``0``;
            the per-module losses are not accumulated into it
            (NOTE(review): looks like a placeholder -- confirm no caller relies on it).
        """
        detached_loss = 0
        for k in self.detached_update_fns:
            detached_module = getattr(self, k)
            with common.RequiresGrad(detached_module):
                with torch.cuda.amp.autocast(enabled=self._use_amp):
                    add_loss, add_metrics = self.detached_update_fns[k](self, k, data, outputs, metrics)
                metrics.update(add_metrics)
                opt_metrics = self.model_opt(add_loss, detached_module.parameters())
                metrics.update({f'{k}_{m}': opt_metrics[m] for m in opt_metrics})
        return detached_loss, metrics

    def update_additional_e2e_modules(self, data, outputs, model_loss, metrics):
        """Add the losses of all end-to-end modules to ``model_loss``.

        Returns:
            ``(model_loss, metrics)`` with each module's loss added and its
            metrics merged.
        """
        for k in self.e2e_update_fns:
            add_loss, add_metrics = self.e2e_update_fns[k](self, k, data, outputs, metrics)
            model_loss += add_loss
            metrics.update(add_metrics)
        return model_loss, metrics

    def observe_data(self, data, state=None):
        """Encode a batch and run the RSSM over it without computing head losses.

        Returns:
            ``(outs, metrics)`` where ``outs`` holds ``embed``, ``post``,
            ``prior`` and ``is_terminal`` and ``metrics`` holds the mean KL
            value under ``'model_kl'``.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        post, prior = self.rssm.observe(embed, data['action'], data['is_first'], state)
        kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
        outs = dict(embed=embed, post=post, prior=prior, is_terminal=data['is_terminal'])
        return outs, {'model_kl': kl_value.mean()}

    def loss(self, data, state=None):
        """Compute the world-model loss for a batch.

        The loss is the KL term plus the negative mean log-likelihood of every
        head output, each weighted by ``cfg.loss_scales`` (default ``1.0``).
        Heads not listed in ``self.grad_heads`` receive detached inputs.

        Returns:
            ``(model_loss, last_state, outs, metrics)`` where ``last_state`` is
            the posterior at the final index of axis 1.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        post, prior = self.rssm.observe(embed, data['action'], data['is_first'], state)
        kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
        assert len(kl_loss.shape) == 0 or (len(kl_loss.shape) == 1 and kl_loss.shape[0] == 1), kl_loss.shape

        likes = {}
        losses = {'kl': kl_loss}
        feat = self.rssm.get_feat(post)
        for name, head in self.heads.items():
            grad_head = (name in self.grad_heads)
            inp = self.decoder_input_fn(post) if name == 'decoder' else feat
            inp = inp if grad_head else stop_gradient(inp)
            out = head(inp)
            # A head may return one distribution or a dict of them keyed by target.
            dists = out if isinstance(out, dict) else {name: out}
            for key, dist in dists.items():
                like = dist.log_prob(data[key])
                likes[key] = like
                losses[key] = -like.mean()

        model_loss = sum(self.cfg.loss_scales.get(k, 1.0) * v for k, v in losses.items())
        outs = dict(embed=embed, feat=feat, post=post, prior=prior, likes=likes, kl=kl_value)
        metrics = {f'{name}_loss': value for name, value in losses.items()}
        metrics['model_kl'] = kl_value.mean()
        metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean()
        metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean()
        last_state = {k: v[:, -1] for k, v in post.items()}
        return model_loss, last_state, outs, metrics

    def imagine(self, policy, start, is_terminal, horizon, task_cond=None, eval_policy=False):
        """Roll the prior forward from ``start`` under ``policy``.

        Args:
            policy: Callable mapping features to a distribution exposing
                ``sample()`` and ``mean``.
            start: Dict of state tensors; the first two axes are merged before
                the rollout.
            is_terminal: Optional terminal flags for the start states. Only
                used when a ``'discount'`` head exists. May be a bool tensor;
                it is cast to float before use.
            horizon: Number of imagined steps.
            task_cond: Optional tensor concatenated to the features fed to the
                policy; stored in the result under ``'task'``.
            eval_policy: Use the policy mean instead of sampling.

        Returns:
            Dict of tensors stacked along a new leading axis of length
            ``horizon + 1``, including ``'feat'``, ``'action'``,
            ``'discount'`` and ``'weight'``.
        """
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}
        start['feat'] = self.rssm.get_feat(start)
        inp = start['feat'] if task_cond is None else torch.cat([start['feat'], task_cond], dim=-1)
        policy_dist = policy(inp)
        # The first action slot is a zero placeholder shaped like a policy sample.
        start['action'] = torch.zeros_like(policy_dist.sample(), device=self.device)

        seq = {k: [v] for k, v in start.items()}
        if task_cond is not None:
            seq['task'] = [task_cond]
        for _ in range(horizon):
            inp = seq['feat'][-1] if task_cond is None else torch.cat([seq['feat'][-1], task_cond], dim=-1)
            policy_dist = policy(stop_gradient(inp))
            action = policy_dist.sample() if not eval_policy else policy_dist.mean
            state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
            feat = self.rssm.get_feat(state)
            for key, value in {**state, 'action': action, 'feat': feat}.items():
                seq[key].append(value)
            if task_cond is not None:
                seq['task'].append(task_cond)
        # shape will be (T, B, *DIMS)
        seq = {k: torch.stack(v, 0) for k, v in seq.items()}

        if 'discount' in self.heads:
            # NOTE(review): other heads in this file read `.mean` as an attribute;
            # here it is called -- confirm the discount head's distribution
            # exposes a callable `mean`.
            disc = self.heads['discount'](seq['feat']).mean()
            if is_terminal is not None:
                # Override discount prediction for the first step with the true
                # discount factor from the replay buffer. The flag is cast to
                # float first (as `preprocess` does) because `1.0 - tensor` is
                # rejected by PyTorch for bool tensors.
                true_first = 1.0 - flatten(is_terminal).float()
                disc = torch.cat([true_first[None], disc[1:]], 0)
        else:
            disc = torch.ones(list(seq['feat'].shape[:-1]) + [1], device=self.device)
        seq['discount'] = disc * self.cfg.discount
        # Shift discount factors because they imply whether the following state
        # will be valid, not whether the current state is valid.
        seq['weight'] = torch.cumprod(torch.cat([torch.ones_like(disc[:1], device=self.device), disc[:-1]], 0), 0)
        return seq

    def preprocess(self, obs):
        """Return a shallow copy of ``obs`` prepared for the model.

        * ``uint8`` entries (NumPy or torch) are mapped to ``value / 255 - 0.5``;
          keys starting with ``log_`` are left untouched.
        * ``obs['reward']`` is transformed according to ``cfg.clip_rewards``
          (``'identity'``, ``'sign'`` or ``'tanh'``).
        * ``obs['discount']`` is set to ``1 - is_terminal`` (as float) and given
          a trailing axis if it has fewer dimensions than the reward.
        """
        obs = obs.copy()
        for key, value in obs.items():
            if key.startswith('log_'):
                continue
            if value.dtype in [np.uint8, torch.uint8]:
                value = value / 255.0 - 0.5
            obs[key] = value
        reward_transforms = {
            'identity': nn.Identity(),
            'sign': torch.sign,
            'tanh': torch.tanh,
        }
        obs['reward'] = reward_transforms[self.cfg.clip_rewards](obs['reward'])
        obs['discount'] = (1.0 - obs['is_terminal'].float())
        if len(obs['discount'].shape) < len(obs['reward'].shape):
            obs['discount'] = obs['discount'].unsqueeze(-1)
        return obs

    def video_pred(self, data, key, nvid=8):
        """Build a comparison video of ground truth, model output and error.

        The first five steps are reconstructed from the posterior; the
        remaining steps are predicted open-loop from the recorded actions.
        Truth, model and error are concatenated along axis 3.

        Args:
            data: Preprocessed batch containing ``key``, ``'action'`` and
                ``'is_first'``.
            key: Observation key to decode.
            nvid: Number of batch elements to include.
        """
        decoder = self.heads['decoder']  # B, T, C, H, W
        truth = data[key][:nvid] + 0.5
        embed = self.encoder(data)
        states, _ = self.rssm.observe(
            embed[:nvid, :5], data['action'][:nvid, :5], data['is_first'][:nvid, :5])
        recon = decoder(self.decoder_input_fn(states))[key].mean[:nvid]  # mode
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.rssm.imagine(data['action'][:nvid, 5:], init)
        prior_recon = decoder(self.decoder_input_fn(prior))[key].mean  # mode
        model = torch.clip(torch.cat([recon[:, :5] + 0.5, prior_recon + 0.5], 1), 0, 1)
        error = (model - truth + 1) / 2
        video = torch.cat([truth, model, error], 3)
        # Unpacking is kept from the original; it raises if `video` is not 5-D.
        B, T, C, H, W = video.shape
        return video
class ActorCritic(Module):
    """Actor and critic trained on sequences imagined by a world model.

    Holds the ``actor`` and ``critic`` networks, an optional slowly-updated
    target critic, their optimizers and a reward normaliser.
    """

    def __init__(self, config, act_spec, feat_size, name=''):
        """Build the networks, optimizers and reward normalisation.

        Args:
            config: Configuration object (actor/critic/optimizer settings,
                ``slow_target``, ``reward_ema``, ...).
            act_spec: Action spec; ``shape[0]`` is the actor output size.
            feat_size: Input feature size for actor and critic.
            name: Optional prefix used to pick ``<name>_actor_grad`` from the
                config and to prefix returned metric names.
        """
        super().__init__()
        self.name = name
        self.cfg = config
        self.act_spec = act_spec
        self._use_amp = (config.precision == 16)
        self.device = config.device

        if getattr(self.cfg, 'discrete_actions', False):
            self.cfg.actor.dist = 'onehot'

        self.actor_grad = getattr(self.cfg, f'{self.name}_actor_grad'.strip('_'))

        inp_size = feat_size
        self.actor = common.MLP(inp_size, act_spec.shape[0], **self.cfg.actor)
        self.critic = common.MLP(inp_size, (1,), **self.cfg.critic)
        if self.cfg.slow_target:
            self._target_critic = common.MLP(inp_size, (1,), **self.cfg.critic)
            self._updates = 0
        else:
            self._target_critic = self.critic
        self.actor_opt = common.Optimizer('actor', self.actor.parameters(), **self.cfg.actor_opt, use_amp=self._use_amp)
        self.critic_opt = common.Optimizer('critic', self.critic.parameters(), **self.cfg.critic_opt, use_amp=self._use_amp)

        if self.cfg.reward_ema:
            # register ema_vals to nn.Module for enabling torch.save and torch.load
            self.register_buffer("ema_vals", torch.zeros((2,)).to(self.device))
            self.reward_ema = common.RewardEMA(device=self.device)
            self.rewnorm = common.StreamNorm(momentum=1, scale=1.0, device=self.device)
        else:
            self.rewnorm = common.StreamNorm(**self.cfg.reward_norm, device=self.device)

        # Zero-initialise the critic's output layer.
        with torch.no_grad():
            for p in self.critic._out.parameters():
                p.data = p.data * 0
        # Hard copy critic initial params. `clone()` gives the target its own
        # storage; assigning `s.data` directly would make both networks share
        # memory, so in-place critic updates would also move the target.
        for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
            d.data = s.data.clone()

    def update(self, world_model, start, is_terminal, reward_fn):
        """Run one actor update and one critic update on imagined data.

        Args:
            world_model: Object providing ``imagine(policy, start,
                is_terminal, horizon)``.
            start: Start states for imagination.
            is_terminal: Terminal flags for the start states.
            reward_fn: Callable mapping the imagined sequence to rewards.

        Returns:
            Metrics dict whose keys are prefixed with ``<name>_`` when a name
            was given.
        """
        metrics = {}
        hor = self.cfg.imag_horizon
        # The weights are is_terminal flags for the imagination start states.
        # Technically, they should multiply the losses from the second trajectory
        # step onwards, which is the first imagined step. However, we are not
        # training the action that led into the first step anyway, so we can use
        # them to scale the whole sequence.
        with common.RequiresGrad(self.actor):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                seq = world_model.imagine(self.actor, start, is_terminal, hor)
                reward = reward_fn(seq)
                seq['reward'], mets1 = self.rewnorm(reward)
                mets1 = {f'reward_{k}': v for k, v in mets1.items()}
                target, mets2, baseline = self.target(seq)
                actor_loss, mets3 = self.actor_loss(seq, target, baseline)
            metrics.update(self.actor_opt(actor_loss, self.actor.parameters()))
        with common.RequiresGrad(self.critic):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                seq = {k: stop_gradient(v) for k, v in seq.items()}
                critic_loss, mets4 = self.critic_loss(seq, target)
            metrics.update(self.critic_opt(critic_loss, self.critic.parameters()))
        metrics.update(**mets1, **mets2, **mets3, **mets4)
        self.update_slow_target()
        return {f'{self.name}_{k}'.strip('_'): v for k, v in metrics.items()}

    def actor_loss(self, seq, target, baseline):
        """Compute the actor loss for an imagined sequence.

        Args:
            seq: Imagined sequence with ``'feat'``, ``'weight'`` and (for
                ``'reinforce'``) ``'action'``.
            target: Value targets as returned by ``target``. Not modified.
            baseline: Baseline values as returned by ``target``.

        Returns:
            ``(actor_loss, metrics)``.

        Raises:
            NotImplementedError: If ``self.actor_grad`` is neither
                ``'dynamics'`` nor ``'reinforce'``.
        """
        # Two state-actions are lost at the end of the trajectory, one for the bootstrap
        # value prediction and one because the corresponding action does not lead
        # anywhere anymore. One target is lost at the start of the trajectory
        # because the initial state comes from the replay buffer.
        policy = self.actor(stop_gradient(seq['feat'][:-2]))  # actions are the ones in [1:-1]
        metrics = {}
        if self.cfg.reward_ema:
            offset, scale = self.reward_ema(target, self.ema_vals)
            normed_target = (target - offset) / scale
            normed_baseline = (baseline - offset) / scale
            metrics['normed_target_mean'] = normed_target.mean()
            metrics['normed_target_std'] = normed_target.std()
            metrics["reward_ema_005"] = self.ema_vals[0]
            metrics["reward_ema_095"] = self.ema_vals[1]
        else:
            normed_target = target
            normed_baseline = baseline

        if self.actor_grad == 'dynamics':
            objective = normed_target[1:]
        elif self.actor_grad == 'reinforce':
            advantage = normed_target[1:] - normed_baseline[1:]
            objective = policy.log_prob(stop_gradient(seq['action'][1:-1]))[:, :, None] * advantage
        else:
            raise NotImplementedError(self.actor_grad)

        ent = policy.entropy()[:, :, None]
        ent_scale = self.cfg.actor_ent
        # Out-of-place add: in the 'dynamics' branch `objective` is a view of
        # `target`, so `+=` would write the entropy bonus into the caller's
        # target tensor, which is later reused for the critic loss.
        objective = objective + ent_scale * ent
        metrics['actor_ent'] = ent.mean()
        metrics['actor_ent_scale'] = ent_scale
        weight = stop_gradient(seq['weight'])
        actor_loss = -(weight[:-2] * objective).mean()
        return actor_loss, metrics

    def critic_loss(self, seq, target):
        """Compute the critic loss: weighted negative log-likelihood of ``target``.

        Returns:
            ``(critic_loss, metrics)`` where ``metrics['critic']`` is the mean
            of the critic's predicted mean.
        """
        feat = seq['feat'][:-1]
        target = stop_gradient(target)
        weight = stop_gradient(seq['weight'])
        dist = self.critic(feat)
        critic_loss = -(dist.log_prob(target)[:, :, None] * weight[:-1]).mean()
        metrics = {'critic': dist.mean.mean()}
        return critic_loss, metrics

    def target(self, seq):
        """Compute lambda-return targets from the target critic.

        Returns:
            ``(target, metrics, baseline)`` where ``baseline`` is the target
            critic's value for every step except the last.
        """
        reward = seq['reward']
        disc = seq['discount']
        value = self._target_critic(seq['feat']).mean
        # Skipping last time step because it is used for bootstrapping.
        target = common.lambda_return(
            reward[:-1], value[:-1], disc[:-1],
            bootstrap=value[-1],
            lambda_=self.cfg.discount_lambda,
            axis=0)
        metrics = {}
        metrics['critic_slow'] = value.mean()
        metrics['critic_target'] = target.mean()
        return target, metrics, value[:-1]

    def update_slow_target(self):
        """Blend critic parameters into the target critic.

        Does nothing unless ``cfg.slow_target`` is set. Every
        ``cfg.slow_target_update`` calls the target becomes
        ``mix * critic + (1 - mix) * target``, with ``mix = 1`` on the very
        first call and ``cfg.slow_target_fraction`` afterwards. The call
        counter is incremented on every call.
        """
        if not self.cfg.slow_target:
            return
        if self._updates % self.cfg.slow_target_update == 0:
            mix = 1.0 if self._updates == 0 else float(self.cfg.slow_target_fraction)
            for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
                d.data = mix * s.data + (1 - mix) * d.data
        self._updates += 1