# genrl/agent/dreamer.py
# Scraped from the Hugging Face file viewer: author mazpie, commit 2d9a728
# ("Initial commit"), 19.4 kB; page labels "raw", "history", "blame", "No virus".
import torch.nn as nn
import torch
import tools.utils as utils
import agent.dreamer_utils as common
from collections import OrderedDict
import numpy as np
from tools.genrl_utils import *
def stop_gradient(x):
    """Return ``x`` detached from the autograd graph (TF-style ``stop_gradient``).

    The result shares storage with ``x`` but no gradient flows back through it.
    """
    detached = x.detach()
    return detached
# Short alias for torch's base module class; the model classes in this file subclass it.
Module = torch.nn.Module
def env_reward(agent, seq):
    """Reward predicted by the agent's world-model reward head.

    Applies ``agent.wm.heads['reward']`` to ``seq['feat']`` and returns the
    ``.mean`` attribute of the resulting distribution.
    """
    reward_head = agent.wm.heads['reward']
    reward_dist = reward_head(seq['feat'])
    return reward_dist.mean
class DreamerAgent(Module):
    """Dreamer agent combining a :class:`WorldModel` with an :class:`ActorCritic`.

    The world model is trained on replay batches (:meth:`update_wm`). The
    acting behaviour is trained on trajectories imagined from the resulting
    posterior states (:meth:`update_acting_behavior`).
    """

    def __init__(self,
                 name, cfg, obs_space, act_spec, **kwargs):
        super().__init__()
        self.name = name
        self.cfg = cfg
        # NOTE: mutates the caller's config object in place.
        self.cfg.update(**kwargs)
        self.obs_space = obs_space
        self.act_spec = act_spec
        self._use_amp = (cfg.precision == 16)
        self.device = cfg.device
        self.act_dim = act_spec.shape[0]
        self.wm = WorldModel(cfg, obs_space, self.act_dim,)
        self.instantiate_acting_behavior()
        self.to(cfg.device)
        # Gradients are disabled by default; the update methods of the
        # sub-modules re-enable them inside common.RequiresGrad contexts.
        self.requires_grad_(requires_grad=False)

    def instantiate_acting_behavior(self,):
        """Create the actor-critic that acts on world-model features."""
        self._acting_behavior = ActorCritic(self.cfg, self.act_spec, self.wm.inp_size).to(self.device)

    def act(self, obs, meta, step, eval_mode, state):
        """Choose an action for a single observation.

        Args:
            obs: mapping of observation arrays for one step. Each value is
                copied, converted to a tensor on ``self.device`` and given a
                leading batch dimension of size 1.
            meta, step: unused.
            eval_mode: if true, use the policy mean (and, when
                ``cfg.eval_state_mean`` is set, do not sample the latent).
            state: ``None`` to start from the RSSM initial state, otherwise the
                ``(latent, action)`` pair returned by the previous call.

        Returns:
            ``(action, new_state)``. ``action`` is a numpy array for the single
            batch element. In random-action mode, ``action`` is drawn uniformly
            from [-1, 1] and ``new_state`` is ``(None, None)``.
        """
        if self.cfg.only_random_actions:
            return np.random.uniform(-1, 1, self.act_dim,).astype(self.act_spec.dtype), (None, None)
        obs = {k: torch.as_tensor(np.copy(v), device=self.device).unsqueeze(0) for k, v in obs.items()}
        if state is None:
            latent = self.wm.rssm.initial(len(obs['reward']))
            action = torch.zeros((len(obs['reward']),) + self.act_spec.shape, device=self.device)
        else:
            latent, action = state
        embed = self.wm.encoder(self.wm.preprocess(obs))
        should_sample = (not eval_mode) or (not self.cfg.eval_state_mean)
        latent, _ = self.wm.rssm.obs_step(latent, action, embed, obs['is_first'], should_sample)
        feat = self.wm.rssm.get_feat(latent)
        actor = self._acting_behavior.actor(feat)
        if eval_mode:
            # Some distribution wrappers do not expose a usable ``mean``
            # (missing attribute, or torch raising NotImplementedError); they
            # keep it in ``_mean`` instead. Catch only those failures rather
            # than every exception.
            try:
                action = actor.mean
            except (AttributeError, NotImplementedError):
                action = actor._mean
        else:
            action = actor.sample()
        new_state = (latent, action)
        return action.cpu().numpy()[0], new_state

    def update_wm(self, data, step):
        """Run one world-model update on ``data``.

        Returns ``(state, outputs, metrics)`` from ``WorldModel.update``, with
        the unprocessed ``data['is_terminal']`` stored in ``outputs``.
        """
        state, outputs, mets = self.wm.update(data, state=None)
        outputs['is_terminal'] = data['is_terminal']
        metrics = dict(mets)
        return state, outputs, metrics

    def update_acting_behavior(self, state=None, outputs=None, metrics=None, data=None, reward_fn=None):
        """Train the acting behaviour on trajectories imagined from posteriors.

        Args:
            state: unused; kept for interface compatibility.
            outputs: world-model outputs (as returned by :meth:`update_wm`).
                If ``None``, posteriors are recomputed from ``data``.
            metrics: dict to extend with the behaviour metrics; a new dict is
                created for each call when omitted.
            data: replay batch, required when ``outputs`` is ``None``.
            reward_fn: optional ``f(agent, seq)``. If ``None``, the module-level
                function named by ``cfg.acting_reward_fn`` is used.

        Returns:
            ``(start, metrics)``. ``start`` holds the detached posterior states
            used as imagination starts (``{}`` in random-action mode).
        """
        # A shared ``{}`` default would be mutated by ``metrics.update`` below
        # and leak metrics between calls.
        if metrics is None:
            metrics = {}
        if self.cfg.only_random_actions:
            return {}, metrics
        if outputs is not None:
            post = outputs['post']
            is_terminal = outputs['is_terminal']
        else:
            data = self.wm.preprocess(data)
            embed = self.wm.encoder(data)
            post, _ = self.wm.rssm.observe(
                embed, data['action'], data['is_first'])
            is_terminal = data['is_terminal']
        # Imagination starts from posterior states cut off from the world-model graph.
        start = {k: stop_gradient(v) for k, v in post.items()}
        if reward_fn is None:
            acting_reward_fn = lambda seq: globals()[self.cfg.acting_reward_fn](self, seq)
        else:
            acting_reward_fn = lambda seq: reward_fn(self, seq)
        metrics.update(self._acting_behavior.update(self.wm, start, is_terminal, acting_reward_fn))
        return start, metrics

    def update(self, data, step):
        """Update the world model, then the acting behaviour; return ``(state, metrics)``."""
        state, outputs, metrics = self.update_wm(data, step)
        _, metrics = self.update_acting_behavior(state, outputs, metrics, data)
        return state, metrics

    def report(self, data):
        """Build a dict of open-loop video predictions and optional extra reports.

        One ``openl_<key>`` entry is produced per decoder CNN key (``/``
        replaced by ``_``). Each name in ``cfg.additional_report_fns`` (if
        present) is looked up as a module-level function, called as
        ``fn(self, preprocessed_data)``, and its dict is merged in.
        """
        report = {}
        data = self.wm.preprocess(data)
        for key in self.wm.heads['decoder'].cnn_keys:
            name = key.replace('/', '_')
            report[f'openl_{name}'] = self.wm.video_pred(data, key)
        for fn in getattr(self.cfg, 'additional_report_fns', []):
            call_fn = globals()[fn]
            additional_report = call_fn(self, data)
            report.update(additional_report)
        return report

    def get_meta_specs(self):
        """This agent uses no meta information: return an empty tuple."""
        return tuple()

    def init_meta(self):
        """Return an empty meta dict."""
        return OrderedDict()

    def update_meta(self, meta, global_step, time_step, finetune=False):
        """Return ``meta`` unchanged."""
        return meta
class WorldModel(Module):
    """Dreamer world model: observation encoder, ensemble RSSM and prediction heads.

    The ``decoder`` and ``reward`` heads are always present. A ``discount``
    head is added when ``config.pred_discount`` is set. All parameters are
    trained by a single ``common.Optimizer`` (``self.model_opt``). Extra
    modules can be attached with :meth:`add_module_to_update`.
    """

    def __init__(self, config, obs_space, act_dim,):
        super().__init__()
        shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
        self.shapes = shapes
        self.cfg = config
        self.device = config.device
        self.encoder = common.Encoder(shapes, **config.encoder)
        # Compute the embedding size by running the encoder once on an
        # all-zeros batch of size 1 and reading output dimension 1.
        with torch.no_grad():
            zeros = {k: torch.zeros((1,) + v) for k, v in shapes.items()}
            outs = self.encoder(zeros)
            embed_dim = outs.shape[1]
        self.embed_dim = embed_dim
        self.rssm = common.EnsembleRSSM(**config.rssm, action_dim=act_dim, embed_dim=embed_dim, device=self.device,)
        self.heads = {}
        self._use_amp = (config.precision == 16)
        self.inp_size = self.rssm.get_feat_size()
        # The decoder input is chosen by config: decoder_inputs=X resolves to
        # rssm.get_X / rssm.get_X_size.
        self.decoder_input_fn = getattr(self.rssm, f'get_{config.decoder_inputs}')
        self.decoder_input_size = getattr(self.rssm, f'get_{config.decoder_inputs}_size')()
        self.heads['decoder'] = common.Decoder(shapes, **config.decoder, embed_dim=self.decoder_input_size, image_dist=config.image_dist)
        self.heads['reward'] = common.MLP(self.inp_size, (1,), **config.reward_head)
        # Zero-initialise the reward head's output layer.
        with torch.no_grad():
            for p in self.heads['reward']._out.parameters():
                p.data = p.data * 0
        if config.pred_discount:
            self.heads['discount'] = common.MLP(self.inp_size, (1,), **config.discount_head)
        for name in config.grad_heads:
            assert name in self.heads, name
        self.grad_heads = config.grad_heads
        self.heads = nn.ModuleDict(self.heads)
        self.model_opt = common.Optimizer('model', self.parameters(), **config.model_opt, use_amp=self._use_amp)
        self.e2e_update_fns = {}
        self.detached_update_fns = {}
        self.eval()

    def add_module_to_update(self, name, module, update_fn, detached=False):
        """Register ``module`` as a submodule trained alongside the world model.

        ``update_fn(wm, name, data, outputs, metrics)`` must return
        ``(loss, metrics)``. With ``detached=False`` the loss is added to the
        model loss. Otherwise it is optimised separately on the module's own
        parameters. The model optimizer is rebuilt so it covers the new
        parameters.
        """
        self.add_module(name, module)
        if detached:
            self.detached_update_fns[name] = update_fn
        else:
            self.e2e_update_fns[name] = update_fn
        self.model_opt = common.Optimizer('model', self.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)

    def update(self, data, state=None):
        """Run one optimisation step on ``data``; return ``(state, outputs, metrics)``.

        Config flags (read with ``getattr`` and defaulting to False):
        ``freeze_decoder`` freezes the decoder. ``freeze_post`` /
        ``freeze_model`` also freeze the encoder and RSSM and clear
        ``grad_heads``. With only ``freeze_post``, the RSSM's ensemble image
        layers stay trainable, so only the prior is updated.
        """
        self.train()
        with common.RequiresGrad(self):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                if getattr(self.cfg, "freeze_decoder", False):
                    self.heads['decoder'].requires_grad_(False)
                if getattr(self.cfg, "freeze_post", False) or getattr(self.cfg, "freeze_model", False):
                    self.heads['decoder'].requires_grad_(False)
                    self.encoder.requires_grad_(False)
                    # Updating only prior.
                    # NOTE: this assignment persists after update() returns.
                    self.grad_heads = []
                    self.rssm.requires_grad_(False)
                    if not getattr(self.cfg, "freeze_model", False):
                        self.rssm._ensemble_img_out.requires_grad_(True)
                        self.rssm._ensemble_img_dist.requires_grad_(True)
                model_loss, state, outputs, metrics = self.loss(data, state)
                model_loss, metrics = self.update_additional_e2e_modules(data, outputs, model_loss, metrics)
            metrics.update(self.model_opt(model_loss, self.parameters()))
        if len(self.detached_update_fns) > 0:
            detached_loss, metrics = self.update_additional_detached_modules(data, outputs, metrics)
        self.eval()
        return state, outputs, metrics

    def update_additional_detached_modules(self, data, outputs, metrics):
        """Optimise each detached module separately on its own loss.

        Optimizer metrics are prefixed with the module name.

        Returns:
            ``(detached_loss, metrics)``. ``detached_loss`` is the sum of the
            individual losses, detached from the graph; it is ``0`` when no
            modules are registered.
        """
        detached_loss = 0
        for k in self.detached_update_fns:
            detached_module = getattr(self, k)
            with common.RequiresGrad(detached_module):
                with torch.cuda.amp.autocast(enabled=self._use_amp):
                    add_loss, add_metrics = self.detached_update_fns[k](self, k, data, outputs, metrics)
                    metrics.update(add_metrics)
                opt_metrics = self.model_opt(add_loss, detached_module.parameters())
                metrics.update({f'{k}_{m}': opt_metrics[m] for m in opt_metrics})
            # Previously never accumulated, so the returned value was always 0.
            detached_loss = detached_loss + stop_gradient(add_loss)
        return detached_loss, metrics

    def update_additional_e2e_modules(self, data, outputs, model_loss, metrics):
        """Add each end-to-end module's loss to ``model_loss``; return ``(model_loss, metrics)``."""
        for k in self.e2e_update_fns:
            add_loss, add_metrics = self.e2e_update_fns[k](self, k, data, outputs, metrics)
            model_loss += add_loss
            metrics.update(add_metrics)
        return model_loss, metrics

    def observe_data(self, data, state=None):
        """Encode and filter ``data`` through the RSSM without computing head losses.

        Returns ``(outs, metrics)``. ``outs`` holds ``embed``, ``post``,
        ``prior`` and ``is_terminal``; ``metrics`` holds the mean KL value.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        post, prior = self.rssm.observe(
            embed, data['action'], data['is_first'], state)
        kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
        outs = dict(embed=embed, post=post, prior=prior, is_terminal=data['is_terminal'])
        return outs, {'model_kl': kl_value.mean()}

    def loss(self, data, state=None):
        """Compute the world-model loss.

        The loss is the KL term plus the negative log-likelihood of every head
        output, each weighted by ``cfg.loss_scales`` (default 1.0). Head inputs
        are detached unless the head is listed in ``self.grad_heads``.

        Returns ``(model_loss, last_state, outs, metrics)``. ``last_state`` is
        the posterior at index -1 of dimension 1.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        post, prior = self.rssm.observe(
            embed, data['action'], data['is_first'], state)
        kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
        assert len(kl_loss.shape) == 0 or (len(kl_loss.shape) == 1 and kl_loss.shape[0] == 1), kl_loss.shape
        likes = {}
        losses = {'kl': kl_loss}
        feat = self.rssm.get_feat(post)
        for name, head in self.heads.items():
            grad_head = (name in self.grad_heads)
            if name == 'decoder':
                inp = self.decoder_input_fn(post)
            else:
                inp = feat
            inp = inp if grad_head else stop_gradient(inp)
            out = head(inp)
            # The decoder may return one distribution per observation key.
            dists = out if isinstance(out, dict) else {name: out}
            for key, dist in dists.items():
                like = dist.log_prob(data[key])
                likes[key] = like
                losses[key] = -like.mean()
        model_loss = sum(
            self.cfg.loss_scales.get(k, 1.0) * v for k, v in losses.items())
        outs = dict(
            embed=embed, feat=feat, post=post,
            prior=prior, likes=likes, kl=kl_value)
        metrics = {f'{name}_loss': value for name, value in losses.items()}
        metrics['model_kl'] = kl_value.mean()
        metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean()
        metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean()
        last_state = {k: v[:, -1] for k, v in post.items()}
        return model_loss, last_state, outs, metrics

    def imagine(self, policy, start, is_terminal, horizon, task_cond=None, eval_policy=False):
        """Roll out ``policy`` in latent space for ``horizon`` steps.

        The first two dimensions of every ``start`` tensor are flattened into
        one batch dimension. If ``task_cond`` is given, it is concatenated to
        the features fed to the policy and stored under ``'task'``. With
        ``eval_policy``, the policy mean is used instead of samples.

        Returns a dict of tensors stacked on a new leading axis of length
        ``horizon + 1``. It includes ``'discount'`` (scaled by
        ``cfg.discount``) and ``'weight'``, the cumulative product of the
        unscaled discounts shifted by one step.
        """
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}
        start['feat'] = self.rssm.get_feat(start)
        inp = start['feat'] if task_cond is None else torch.cat([start['feat'], task_cond], dim=-1)
        policy_dist = policy(inp)
        # The start state has no incoming action: use zeros shaped like one.
        start['action'] = torch.zeros_like(policy_dist.sample(), device=self.device)
        seq = {k: [v] for k, v in start.items()}
        if task_cond is not None:
            seq['task'] = [task_cond]
        for _ in range(horizon):
            inp = seq['feat'][-1] if task_cond is None else torch.cat([seq['feat'][-1], task_cond], dim=-1)
            policy_dist = policy(stop_gradient(inp))
            action = policy_dist.sample() if not eval_policy else policy_dist.mean
            state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
            feat = self.rssm.get_feat(state)
            for key, value in {**state, 'action': action, 'feat': feat}.items():
                seq[key].append(value)
            if task_cond is not None:
                seq['task'].append(task_cond)
        # shape will be (T, B, *DIMS)
        seq = {k: torch.stack(v, 0) for k, v in seq.items()}
        if 'discount' in self.heads:
            # ``.mean`` is an attribute, as for every other head distribution in
            # this file. Calling ``.mean()`` on it (as before) would average the
            # tensor to a scalar and break ``disc[1:]`` below.
            disc = self.heads['discount'](seq['feat']).mean
            if is_terminal is not None:
                # Override discount prediction for the first step with the true
                # discount factor from the replay buffer. Casting and reshaping
                # to one step of ``disc`` accepts bool flags and (B, T)-shaped
                # is_terminal; when the shapes already match, this is a no-op.
                true_first = 1.0 - flatten(is_terminal).to(disc.dtype)
                true_first = true_first.reshape(disc.shape[1:])
                disc = torch.cat([true_first[None], disc[1:]], 0)
        else:
            disc = torch.ones(list(seq['feat'].shape[:-1]) + [1], device=self.device)
        seq['discount'] = disc * self.cfg.discount
        # Shift discount factors because they imply whether the following state
        # will be valid, not whether the current state is valid.
        seq['weight'] = torch.cumprod(torch.cat([torch.ones_like(disc[:1], device=self.device), disc[:-1]], 0), 0)
        return seq

    def preprocess(self, obs):
        """Return a shallow copy of ``obs`` normalised for the model.

        - uint8 values (except keys starting with ``log_``) are mapped to
          ``[-0.5, 0.5]``.
        - ``reward`` is transformed by ``cfg.clip_rewards``
          (``identity`` / ``sign`` / ``tanh``).
        - ``discount`` is set to ``1 - is_terminal`` and unsqueezed on the last
          axis if it has fewer dims than ``reward``.
        """
        obs = obs.copy()
        for key, value in obs.items():
            if key.startswith('log_'):
                continue
            if value.dtype in [np.uint8, torch.uint8]:
                value = value / 255.0 - 0.5
            obs[key] = value
        obs['reward'] = {
            'identity': nn.Identity(),
            'sign': torch.sign,
            'tanh': torch.tanh,
        }[self.cfg.clip_rewards](obs['reward'])
        obs['discount'] = (1.0 - obs['is_terminal'].float())
        if len(obs['discount'].shape) < len(obs['reward'].shape):
            obs['discount'] = obs['discount'].unsqueeze(-1)
        return obs

    def video_pred(self, data, key, nvid=8):
        """Build an open-loop prediction video for observation ``key``.

        The first 5 steps are reconstructed from posterior states. The rest
        are imagined from the last posterior using the recorded actions.
        Ground truth, prediction and error are concatenated along dim 3.
        """
        decoder = self.heads['decoder']  # presumably B, T, C, H, W images — confirm
        truth = data[key][:nvid] + 0.5
        embed = self.encoder(data)
        states, _ = self.rssm.observe(
            embed[:nvid, :5], data['action'][:nvid, :5], data['is_first'][:nvid, :5])
        recon = decoder(self.decoder_input_fn(states))[key].mean[:nvid]  # mode
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.rssm.imagine(data['action'][:nvid, 5:], init)
        prior_recon = decoder(self.decoder_input_fn(prior))[key].mean  # mode
        model = torch.clip(torch.cat([recon[:, :5] + 0.5, prior_recon + 0.5], 1), 0, 1)
        error = (model - truth + 1) / 2
        video = torch.cat([truth, model, error], 3)
        # Unpacking raises if the video is not 5-dimensional.
        B, T, C, H, W = video.shape
        return video
class ActorCritic(Module):
    """Actor and critic trained on trajectories imagined by a world model.

    The critic is regressed onto lambda-returns bootstrapped from a target
    critic. The target is a slowly updated copy when ``cfg.slow_target`` is
    set, otherwise the critic itself. The actor maximises the (optionally
    EMA-normalised) return, via ``'dynamics'`` backprop or ``'reinforce'``,
    plus an entropy bonus.
    """

    def __init__(self, config, act_spec, feat_size, name=''):
        super().__init__()
        self.name = name
        self.cfg = config
        self.act_spec = act_spec
        self._use_amp = (config.precision == 16)
        self.device = config.device
        if getattr(self.cfg, 'discrete_actions', False):
            # NOTE: mutates the shared config object.
            self.cfg.actor.dist = 'onehot'
        # Config key is '<name>_actor_grad', or 'actor_grad' when name is ''.
        self.actor_grad = getattr(self.cfg, f'{self.name}_actor_grad'.strip('_'))
        inp_size = feat_size
        self.actor = common.MLP(inp_size, act_spec.shape[0], **self.cfg.actor)
        self.critic = common.MLP(inp_size, (1,), **self.cfg.critic)
        if self.cfg.slow_target:
            self._target_critic = common.MLP(inp_size, (1,), **self.cfg.critic)
            self._updates = 0
        else:
            self._target_critic = self.critic
        self.actor_opt = common.Optimizer('actor', self.actor.parameters(), **self.cfg.actor_opt, use_amp=self._use_amp)
        self.critic_opt = common.Optimizer('critic', self.critic.parameters(), **self.cfg.critic_opt, use_amp=self._use_amp)
        if self.cfg.reward_ema:
            # register ema_vals to nn.Module for enabling torch.save and torch.load
            self.register_buffer("ema_vals", torch.zeros((2,)).to(self.device))
            self.reward_ema = common.RewardEMA(device=self.device)
            self.rewnorm = common.StreamNorm(momentum=1, scale=1.0, device=self.device)
        else:
            self.rewnorm = common.StreamNorm(**self.cfg.reward_norm, device=self.device)
        with torch.no_grad():
            # Zero-initialise the critic's output layer.
            for p in self.critic._out.parameters():
                p.data = p.data * 0
            # Hard-copy the initial critic parameters into the target critic.
            # copy_ keeps the two modules in separate storage instead of
            # aliasing them. When slow_target is off, both are the same module
            # and this is a no-op.
            for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
                d.data.copy_(s.data)

    def update(self, world_model, start, is_terminal, reward_fn):
        """Run one actor step and one critic step on an imagined rollout.

        Args:
            world_model: model providing ``imagine(policy, start, is_terminal, horizon)``.
            start: posterior states to start imagination from.
            is_terminal: terminal flags of the start states (may be ``None``).
            reward_fn: ``f(seq) -> reward`` evaluated on the imagined sequence.

        Returns:
            Metrics dict with keys prefixed by ``'<name>_'`` (no prefix when
            ``name`` is empty).
        """
        metrics = {}
        hor = self.cfg.imag_horizon
        # The weights are is_terminal flags for the imagination start states.
        # Technically, they should multiply the losses from the second trajectory
        # step onwards, which is the first imagined step. However, we are not
        # training the action that led into the first step anyway, so we can use
        # them to scale the whole sequence.
        with common.RequiresGrad(self.actor):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                seq = world_model.imagine(self.actor, start, is_terminal, hor)
                reward = reward_fn(seq)
                seq['reward'], mets1 = self.rewnorm(reward)
                mets1 = {f'reward_{k}': v for k, v in mets1.items()}
                target, mets2, baseline = self.target(seq)
                actor_loss, mets3 = self.actor_loss(seq, target, baseline)
            metrics.update(self.actor_opt(actor_loss, self.actor.parameters()))
        with common.RequiresGrad(self.critic):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                seq = {k: stop_gradient(v) for k, v in seq.items()}
                critic_loss, mets4 = self.critic_loss(seq, target)
            metrics.update(self.critic_opt(critic_loss, self.critic.parameters()))
        metrics.update(**mets1, **mets2, **mets3, **mets4)
        self.update_slow_target()  # Variables exist after first forward pass.
        return {f'{self.name}_{k}'.strip('_'): v for k, v in metrics.items()}

    def actor_loss(self, seq, target, baseline):
        """Return ``(actor_loss, metrics)`` for the imagined sequence.

        ``'dynamics'`` maximises the target directly. ``'reinforce'`` uses
        log-prob times advantage (target minus baseline). Both add
        ``cfg.actor_ent`` times the policy entropy and are weighted by
        ``seq['weight']``.
        """
        # Two state-actions are lost at the end of the trajectory, one for the boostrap
        # value prediction and one because the corresponding action does not lead
        # anywhere anymore. One target is lost at the start of the trajectory
        # because the initial state comes from the replay buffer.
        policy = self.actor(stop_gradient(seq['feat'][:-2]))  # actions are the ones in [1:-1]
        metrics = {}
        if self.cfg.reward_ema:
            offset, scale = self.reward_ema(target, self.ema_vals)
            normed_target = (target - offset) / scale
            normed_baseline = (baseline - offset) / scale
            metrics['normed_target_mean'] = normed_target.mean()
            metrics['normed_target_std'] = normed_target.std()
            metrics["reward_ema_005"] = self.ema_vals[0]
            metrics["reward_ema_095"] = self.ema_vals[1]
        else:
            normed_target = target
            normed_baseline = baseline
        if self.actor_grad == 'dynamics':
            objective = normed_target[1:]
        elif self.actor_grad == 'reinforce':
            advantage = normed_target[1:] - normed_baseline[1:]
            objective = policy.log_prob(stop_gradient(seq['action'][1:-1]))[:, :, None] * advantage
        else:
            raise NotImplementedError(self.actor_grad)
        ent = policy.entropy()[:, :, None]
        ent_scale = self.cfg.actor_ent
        # Out-of-place add. In the 'dynamics' branch without reward EMA,
        # ``objective`` is a view of ``target``, so an in-place ``+=`` would
        # write the entropy bonus into the tensor that update() later uses as
        # the critic's regression target.
        objective = objective + ent_scale * ent
        metrics['actor_ent'] = ent.mean()
        metrics['actor_ent_scale'] = ent_scale
        weight = stop_gradient(seq['weight'])
        actor_loss = -(weight[:-2] * objective).mean()
        return actor_loss, metrics

    def critic_loss(self, seq, target):
        """Weighted negative log-likelihood of ``target`` under the critic; return ``(loss, metrics)``."""
        feat = seq['feat'][:-1]
        target = stop_gradient(target)
        weight = stop_gradient(seq['weight'])
        dist = self.critic(feat)
        critic_loss = -(dist.log_prob(target)[:, :, None] * weight[:-1]).mean()
        metrics = {'critic': dist.mean.mean()}
        return critic_loss, metrics

    def target(self, seq):
        """Compute lambda-returns with ``common.lambda_return``.

        Returns ``(target, metrics, baseline)``. ``baseline`` is the
        target-critic value for every step except the last, which is used
        only for bootstrapping.
        """
        reward = seq['reward']
        disc = seq['discount']
        value = self._target_critic(seq['feat']).mean
        # Skipping last time step because it is used for bootstrapping.
        target = common.lambda_return(
            reward[:-1], value[:-1], disc[:-1],
            bootstrap=value[-1],
            lambda_=self.cfg.discount_lambda,
            axis=0)
        metrics = {}
        metrics['critic_slow'] = value.mean()
        metrics['critic_target'] = target.mean()
        return target, metrics, value[:-1]

    def update_slow_target(self):
        """Polyak-update the target critic every ``cfg.slow_target_update`` calls.

        The first call copies the critic fully (mix = 1). Later updates use
        ``cfg.slow_target_fraction``. Does nothing when ``cfg.slow_target`` is
        off.
        """
        if self.cfg.slow_target:
            if self._updates % self.cfg.slow_target_update == 0:
                mix = 1.0 if self._updates == 0 else float(
                    self.cfg.slow_target_fraction)
                for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1