genrl / agent /plan2explore.py
mazpie's picture
Initial commit
2d9a728
raw
history blame
No virus
4.22 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from agent.dreamer import DreamerAgent, stop_gradient
import agent.dreamer_utils as common
class Disagreement(nn.Module):
    """Ensemble of small MLPs that predict a target vector from ``(obs, action)``.

    Every member is a one-hidden-layer ReLU network fed with the concatenation
    of ``obs`` and ``action`` along the last dimension. ``forward`` returns the
    per-member prediction error (used as a training signal) and
    ``get_disagreement`` returns the variance of the predictions across
    members.

    Args:
        obs_dim: Size of the last dimension of ``obs``.
        action_dim: Size of the last dimension of ``action``.
        hidden_dim: Width of the hidden layer of every member.
        n_models: Number of ensemble members. NOTE: ``get_disagreement`` uses
            ``torch.var`` with its default (unbiased) estimator, which is NaN
            for a single sample, so at least two members are needed for a
            meaningful disagreement value.
        pred_dim: Size of every member's output. Defaults to ``obs_dim``.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, n_models=5, pred_dim=None):
        super().__init__()
        if pred_dim is None:
            pred_dim = obs_dim
        self.ensemble = nn.ModuleList([
            nn.Sequential(
                nn.Linear(obs_dim + action_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, pred_dim),
            )
            for _ in range(n_models)
        ])

    def _predict_all(self, obs, action):
        """Return a list with each ensemble member's prediction for ``(obs, action)``."""
        # The concatenated input is the same for every member, so build it once
        # instead of once per model.
        model_input = torch.cat([obs, action], dim=-1)
        return [model(model_input) for model in self.ensemble]

    def forward(self, obs, action, next_obs):
        """Compute the L2 prediction error of every ensemble member.

        Args:
            obs: Tensor of shape ``(..., obs_dim)``.
            action: Tensor of shape ``(..., action_dim)`` with the same leading
                dimensions as ``obs``.
            next_obs: Prediction target of shape ``(..., pred_dim)``.

        Returns:
            Tensor of shape ``(..., n_models)``; entry ``[..., i]`` is the
            Euclidean norm of ``next_obs - prediction_i``.

        Raises:
            AssertionError: If the first dimensions of the inputs differ.
        """
        assert obs.shape[0] == next_obs.shape[0]
        assert obs.shape[0] == action.shape[0]
        errors = [
            torch.norm(next_obs - next_obs_hat, dim=-1, p=2, keepdim=True)
            for next_obs_hat in self._predict_all(obs, action)
        ]
        # Concatenate on the last axis (the one kept by ``keepdim=True``). For
        # 2-D inputs this equals ``dim=1``; for inputs with extra leading
        # dimensions it still yields one column per model.
        return torch.cat(errors, dim=-1)

    def get_disagreement(self, obs, action):
        """Return the ensemble disagreement for ``(obs, action)``.

        The disagreement is the variance of the members' predictions taken
        across the ensemble, averaged over the prediction dimension.

        Args:
            obs: Tensor of shape ``(..., obs_dim)``.
            action: Tensor of shape ``(..., action_dim)`` with the same leading
                dimensions as ``obs``.

        Returns:
            Tensor with the leading dimensions of ``obs`` (last axis reduced).

        Raises:
            AssertionError: If the first dimensions of the inputs differ.
        """
        assert obs.shape[0] == action.shape[0]
        preds = torch.stack(self._predict_all(obs, action), dim=0)
        return torch.var(preds, dim=0).mean(dim=-1)
class Plan2Explore(DreamerAgent):
    """Dreamer agent whose acting behavior is trained on ensemble disagreement.

    On top of ``DreamerAgent`` this class owns a ``Disagreement`` ensemble that
    is fitted on world-model outputs in ``update``. While ``reward_free`` is
    true, the variance of the ensemble's predictions is handed to the acting
    behavior as its reward function; otherwise the world model's reward head is
    used.
    """

    def __init__(self, **kwargs):
        """Build the base agent, then the disagreement ensemble and its optimizer.

        All keyword arguments are forwarded unchanged to ``DreamerAgent``.
        """
        super().__init__(**kwargs)
        feat_size = self.wm.inp_size
        embed_size = self.wm.embed_dim
        # The ensemble's hidden width is tied to the size of its output.
        self.hidden_dim = embed_size
        self.reward_free = True

        ensemble = Disagreement(
            feat_size,
            self.act_dim,
            self.hidden_dim,
            pred_dim=embed_size,
        )
        self.disagreement = ensemble.to(self.device)

        # optimizers
        self.disagreement_opt = common.Optimizer(
            'disagreement',
            self.disagreement.parameters(),
            **self.cfg.model_opt,
            use_amp=self._use_amp,
        )

        self.disagreement.train()
        # Gradients are switched off here; ``update`` re-enables them for the
        # ensemble through ``common.RequiresGrad`` while it is being trained.
        self.requires_grad_(requires_grad=False)

    def update_disagreement(self, obs, action, next_obs, step):
        """Run one optimizer step on the ensemble and return its metrics.

        The loss is the mean of the per-model errors returned by
        ``self.disagreement``. ``step`` is accepted but not used here.

        Returns:
            The metrics produced by ``self.disagreement_opt`` plus the scalar
            loss under the key ``'disagreement_loss'``.
        """
        per_model_error = self.disagreement(obs, action, next_obs)
        loss = per_model_error.mean()
        metrics = dict(self.disagreement_opt(loss, self.disagreement.parameters()))
        metrics['disagreement_loss'] = loss.item()
        return metrics

    def compute_intr_reward(self, seq):
        """Compute the disagreement-based reward for a sequence.

        ``seq['feat']`` at index ``i`` of the first axis is paired with
        ``seq['action']`` at index ``i + 1``, and the resulting disagreement is
        stored at index ``i + 1`` of the output. Index 0 of the output stays
        zero.

        Args:
            seq: Mapping with ``'feat'`` and ``'action'`` tensors.
                NOTE(review): presumably an imagined rollout with the sequence
                axis first - confirm against the caller.

        Returns:
            Tensor shaped like ``seq['action']`` with the last dimension
            replaced by 1.
        """
        feat = seq['feat'][:-1]
        action = stop_gradient(seq['action'][1:])
        reward_shape = list(seq['action'].shape[:-1]) + [1]
        intr_rew = torch.zeros(reward_shape, device=self.device)

        if len(action.shape) < 3:
            reward = self.disagreement.get_disagreement(feat, action).unsqueeze(-1)
        else:
            # Flatten the two leading axes for the ensemble, then restore them.
            lead, inner, _ = action.shape
            flat_feat = feat.reshape(lead * inner, -1)
            flat_action = action.reshape(lead * inner, -1)
            flat_reward = self.disagreement.get_disagreement(flat_feat, flat_action)
            reward = flat_reward.reshape(lead, inner, 1)

        intr_rew[1:] = reward
        return intr_rew

    def update(self, data, step):
        """Update the world model, the ensemble (if reward-free) and the behavior.

        Args:
            data: Batch mapping; ``'action'`` (3-D) and ``'is_terminal'`` are
                read here, and the whole mapping is passed to ``self.wm.update``.
            step: Passed through to ``update_disagreement``.

        Returns:
            ``(state, metrics)`` where ``state`` is the value returned by
            ``self.wm.update`` and ``metrics`` merges the metrics of every
            sub-update that ran.
        """
        metrics = {}
        batch, length, _ = data['action'].shape

        state, outputs, wm_metrics = self.wm.update(data, state=None)
        metrics.update(wm_metrics)
        start = {key: stop_gradient(value) for key, value in outputs['post'].items()}

        if not self.reward_free:
            def reward_fn(seq):
                return self.wm.heads['reward'](seq['feat']).mean

            metrics.update(self._acting_behavior.update(
                self.wm, start, data['is_terminal'], reward_fn))
            return state, metrics

        # Pair the feature at index t (axis 1) with the action and embedding at
        # index t + 1, which leaves ``length - 1`` pairs per batch element.
        n_rows = batch * (length - 1)
        inp = stop_gradient(outputs['feat'][:, :-1]).reshape(n_rows, -1)
        action = data['action'][:, 1:].reshape(n_rows, -1)
        target = stop_gradient(outputs['embed'][:, 1:]).reshape(n_rows, -1)

        with common.RequiresGrad(self.disagreement), \
                torch.cuda.amp.autocast(enabled=self._use_amp):
            metrics.update(self.update_disagreement(inp, action, target, step))

        metrics.update(self._acting_behavior.update(
            self.wm, start, data['is_terminal'], reward_fn=self.compute_intr_reward))
        return state, metrics