|
from copy import deepcopy |
|
from typing import Tuple, Optional, List, Dict |
|
from easydict import EasyDict |
|
import pickle |
|
import os |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
|
|
from ding.utils import REWARD_MODEL_REGISTRY |
|
from ding.utils import SequenceType |
|
from ding.model.common import FCEncoder |
|
from ding.utils import build_logger |
|
from ding.utils.data import default_collate |
|
|
|
from .base_reward_model import BaseRewardModel |
|
from .rnd_reward_model import collect_states |
|
|
|
|
|
class TrexConvEncoder(nn.Module): |
|
r""" |
|
Overview: |
|
        The ``Convolution Encoder`` used in the Trex reward model to encode raw 2-D (image) observations.
|
Interfaces: |
|
``__init__``, ``forward`` |
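
    Examples (a usage sketch; the ``(4, 84, 84)`` observation shape and the batch size are illustrative \
        assumptions, not values fixed by this class):
        >>> encoder = TrexConvEncoder(obs_shape=(4, 84, 84))
        >>> obs = torch.randn(8, 4, 84, 84)
        >>> out = encoder(obs)
        >>> assert out.shape == (8, 1)  # one scalar output per stacked-frame observation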
|
""" |
|
|
|
def __init__( |
|
self, |
|
obs_shape: SequenceType, |
|
hidden_size_list: SequenceType = [16, 16, 16, 16, 64, 1], |
|
activation: Optional[nn.Module] = nn.LeakyReLU() |
|
) -> None: |
|
r""" |
|
Overview: |
|
            Initialize the Trex convolution encoder according to the given arguments. ``TrexConvEncoder`` \
                differs from the ``ConvEncoder`` in ``ding.model.common.encoder`` in its kernel size and \
                stride settings.
        Arguments:
            - obs_shape (:obj:`SequenceType`): Observation shape, i.e. ``in_channel`` followed by the spatial size.
            - hidden_size_list (:obj:`SequenceType`): The collection of ``hidden_size``; the last two entries \
                define the linear head after the conv layers.
            - activation (:obj:`nn.Module`): The activation function used in the conv layers, \
                default set to ``nn.LeakyReLU()``.
|
""" |
|
super(TrexConvEncoder, self).__init__() |
|
self.obs_shape = obs_shape |
|
self.act = activation |
|
self.hidden_size_list = hidden_size_list |
|
|
|
layers = [] |
|
kernel_size = [7, 5, 3, 3] |
|
stride = [3, 2, 1, 1] |
|
input_size = obs_shape[0] |
|
for i in range(len(kernel_size)): |
|
layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i])) |
|
layers.append(self.act) |
|
input_size = hidden_size_list[i] |
|
layers.append(nn.Flatten()) |
|
self.main = nn.Sequential(*layers) |
|
|
|
flatten_size = self._get_flatten_size() |
|
self.mid = nn.Sequential( |
|
nn.Linear(flatten_size, hidden_size_list[-2]), self.act, |
|
nn.Linear(hidden_size_list[-2], hidden_size_list[-1]) |
|
) |
|
|
|
def _get_flatten_size(self) -> int: |
|
r""" |
|
Overview: |
|
            Compute the flattened feature size after ``self.main``, i.e. the number of ``in_features`` of the \
                first ``nn.Linear`` layer, by passing a dummy observation through the conv stack.
        Returns:
            - flatten_size (:obj:`int`): The flattened feature size.
|
""" |
|
test_data = torch.randn(1, *self.obs_shape) |
|
with torch.no_grad(): |
|
output = self.main(test_data) |
|
return output.shape[1] |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
r""" |
|
Overview: |
|
Return embedding tensor of the env observation |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): Env raw observation |
|
Returns: |
|
- outputs (:obj:`torch.Tensor`): Embedding tensor |
|
""" |
|
x = self.main(x) |
|
x = self.mid(x) |
|
return x |
|
|
|
|
|
class TrexModel(nn.Module): |
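    """
    Overview:
        Reward network used by T-REX: an encoder maps every observation to a scalar reward, and trajectory \
        preferences are predicted from the summed rewards of two trajectories.
    Interfaces:
        ``__init__``, ``cum_return``, ``forward``

    Examples (a usage sketch; the 1-dim ``obs_shape=11`` and the trajectory lengths are illustrative assumptions):
        >>> model = TrexModel(obs_shape=11)
        >>> traj_i, traj_j = torch.randn(20, 11), torch.randn(25, 11)
        >>> logits, abs_r = model(traj_i, traj_j)  # logits: (2,), the summed return of each trajectory
    """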
|
|
|
def __init__(self, obs_shape): |
|
super(TrexModel, self).__init__() |
|
if isinstance(obs_shape, int) or len(obs_shape) == 1: |
|
self.encoder = nn.Sequential(FCEncoder(obs_shape, [512, 64]), nn.Linear(64, 1)) |
|
|
|
elif len(obs_shape) == 3: |
|
self.encoder = TrexConvEncoder(obs_shape) |
|
else: |
|
raise KeyError( |
|
"not support obs_shape for pre-defined encoder: {}, please customize your own Trex model". |
|
format(obs_shape) |
|
) |
|
|
|
def cum_return(self, traj: torch.Tensor, mode: str = 'sum') -> Tuple[torch.Tensor, torch.Tensor]: |
|
        """
        Overview:
            Calculate the return of a trajectory. With ``mode='sum'``, return the summed reward and summed \
            absolute reward over the trajectory; with ``mode='batch'``, return the per-step rewards and their \
            absolute values.
        """
|
r = self.encoder(traj) |
|
if mode == 'sum': |
|
sum_rewards = torch.sum(r) |
|
sum_abs_rewards = torch.sum(torch.abs(r)) |
|
return sum_rewards, sum_abs_rewards |
|
elif mode == 'batch': |
|
return r, torch.abs(r) |
|
else: |
|
raise KeyError("not support mode: {}, please choose mode=sum or mode=batch".format(mode)) |
|
|
|
def forward(self, traj_i: torch.Tensor, traj_j: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
'''compute cumulative return for each trajectory and return logits''' |
|
cum_r_i, abs_r_i = self.cum_return(traj_i) |
|
cum_r_j, abs_r_j = self.cum_return(traj_j) |
|
return torch.cat((cum_r_i.unsqueeze(0), cum_r_j.unsqueeze(0)), 0), abs_r_i + abs_r_j |
|
|
|
|
|
@REWARD_MODEL_REGISTRY.register('trex') |
|
class TrexRewardModel(BaseRewardModel): |
|
""" |
|
Overview: |
|
The Trex reward model class (https://arxiv.org/pdf/1904.06387.pdf) |
|
Interface: |
|
        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
        ``__init__``, ``_train``
|
Config: |
|
== ==================== ====== ============= ============================================ ============= |
|
ID Symbol Type Default Value Description Other(Shape) |
|
== ==================== ====== ============= ============================================ ============= |
|
1 ``type`` str trex | Reward model register name, refer | |
|
| to registry ``REWARD_MODEL_REGISTRY`` | |
|
3 | ``learning_rate`` float 0.00001 | learning rate for optimizer | |
|
4 | ``update_per_`` int 100 | Number of updates per collect | |
|
| ``collect`` | | |
|
5 | ``num_trajs`` int 0 | Number of downsampled full trajectories | |
|
6 | ``num_snippets`` int 6000 | Number of short subtrajectories to sample | |
|
== ==================== ====== ============= ============================================ ============= |
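
    Example:
        A minimal sketch of the config consumed in ``__init__`` (the concrete values, ``exp_name`` and \
        ``obs_shape`` below are illustrative assumptions; ``min_snippet_length`` and ``max_snippet_length`` \
        must be supplied by the user config, and ``episodes_data.pkl`` / ``learning_returns.pkl`` are expected \
        under ``exp_name``):

        .. code-block:: python

            cfg = EasyDict(dict(
                exp_name='trex_exp',
                policy=dict(model=dict(obs_shape=11)),
                reward_model=dict(
                    learning_rate=1e-5,
                    update_per_collect=100,
                    num_trajs=0,
                    num_snippets=6000,
                    min_snippet_length=30,
                    max_snippet_length=100,
                ),
            ))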
|
""" |
|
config = dict( |
|
|
|
type='trex', |
|
|
|
learning_rate=1e-5, |
|
|
|
|
|
|
|
update_per_collect=100, |
|
|
|
num_trajs=0, |
|
|
|
num_snippets=6000, |
|
) |
|
|
|
def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: |
|
""" |
|
Overview: |
|
Initialize ``self.`` See ``help(type(self))`` for accurate signature. |
|
Arguments: |
|
            - config (:obj:`EasyDict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): TensorBoard logger used to record training information
|
""" |
|
super(TrexRewardModel, self).__init__() |
|
self.cfg = config |
|
assert device in ["cpu", "cuda"] or "cuda" in device |
|
self.device = device |
|
self.tb_logger = tb_logger |
|
self.reward_model = TrexModel(self.cfg.policy.model.obs_shape) |
|
self.reward_model.to(self.device) |
|
self.pre_expert_data = [] |
|
self.train_data = [] |
|
self.expert_data_loader = None |
|
self.opt = optim.Adam(self.reward_model.parameters(), config.reward_model.learning_rate) |
|
self.train_iter = 0 |
|
self.learning_returns = [] |
|
self.training_obs = [] |
|
self.training_labels = [] |
|
self.num_trajs = self.cfg.reward_model.num_trajs |
|
self.num_snippets = self.cfg.reward_model.num_snippets |
|
|
|
self.min_snippet_length = config.reward_model.min_snippet_length |
|
|
|
self.max_snippet_length = config.reward_model.max_snippet_length |
|
self.l1_reg = 0 |
|
self.data_for_save = {} |
|
self._logger, self._tb_logger = build_logger( |
|
path='./{}/log/{}'.format(self.cfg.exp_name, 'trex_reward_model'), name='trex_reward_model' |
|
) |
|
self.load_expert_data() |
|
|
|
def load_expert_data(self) -> None: |
|
""" |
|
Overview: |
|
            Load the expert data.
        Effects:
            This is a side effect function which updates ``self.pre_expert_data`` and ``self.learning_returns`` \
            from the pickled files under ``exp_name``, and then builds the training pairs via \
            ``create_training_data``.
|
""" |
|
with open(os.path.join(self.cfg.exp_name, 'episodes_data.pkl'), 'rb') as f: |
|
self.pre_expert_data = pickle.load(f) |
|
with open(os.path.join(self.cfg.exp_name, 'learning_returns.pkl'), 'rb') as f: |
|
self.learning_returns = pickle.load(f) |
|
|
|
self.create_training_data() |
|
self._logger.info("num_training_obs: {}".format(len(self.training_obs))) |
|
self._logger.info("num_labels: {}".format(len(self.training_labels))) |
|
|
|
def create_training_data(self): |
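        """
        Overview:
            Build the pairwise training set used by T-REX. ``num_trajs`` pairs of downsampled full trajectories \
            and ``num_snippets`` pairs of short snippets are sampled from different bins of \
            ``self.pre_expert_data``; each pair gets a binary label indicating which element comes from the \
            higher-indexed (better-ranked) bin.
        Returns:
            - training_obs (:obj:`list`): List of ``(traj_i, traj_j)`` observation pairs.
            - training_labels (:obj:`list`): List of binary preference labels.
        """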
|
num_trajs = self.num_trajs |
|
num_snippets = self.num_snippets |
|
min_snippet_length = self.min_snippet_length |
|
max_snippet_length = self.max_snippet_length |
|
|
|
demo_lengths = [] |
|
for i in range(len(self.pre_expert_data)): |
|
demo_lengths.append([len(d) for d in self.pre_expert_data[i]]) |
|
|
|
self._logger.info("demo_lengths: {}".format(demo_lengths)) |
|
        # ``demo_lengths`` can be ragged (bins may hold episodes of different lengths), so avoid ``np.min`` here.
        min_demo_length = min(min(lengths) for lengths in demo_lengths)
        max_snippet_length = min(min_demo_length, max_snippet_length)
|
self._logger.info("min snippet length: {}".format(min_snippet_length)) |
|
self._logger.info("max snippet length: {}".format(max_snippet_length)) |
|
|
|
|
|
max_traj_length = 0 |
|
num_bins = len(self.pre_expert_data) |
|
assert num_bins >= 2 |
|
|
|
|
|
        # Fixed downsampling scheme for full-trajectory pairs: random start offsets in [0, 6)
        # and random subsampling strides in [3, 7).
        si = np.random.randint(6, size=num_trajs)
        sj = np.random.randint(6, size=num_trajs)
        step = np.random.randint(3, 7, size=num_trajs)
|
for n in range(num_trajs): |
|
|
|
bi, bj = np.random.choice(num_bins, size=(2, ), replace=False) |
|
ti = np.random.choice(len(self.pre_expert_data[bi])) |
|
tj = np.random.choice(len(self.pre_expert_data[bj])) |
|
|
|
traj_i = self.pre_expert_data[bi][ti][si[n]::step[n]] |
|
traj_j = self.pre_expert_data[bj][tj][sj[n]::step[n]] |
|
|
|
            # bi != bj, so ``label == 1`` iff bj > bi, i.e. ``traj_j`` comes from the higher-indexed (better-ranked) bin.
            label = int(bi <= bj)
|
|
|
self.training_obs.append((traj_i, traj_j)) |
|
self.training_labels.append(label) |
|
max_traj_length = max(max_traj_length, len(traj_i), len(traj_j)) |
|
|
|
|
|
rand_length = np.random.randint(min_snippet_length, max_snippet_length, size=num_snippets) |
|
for n in range(num_snippets): |
|
|
|
bi, bj = np.random.choice(num_bins, size=(2, ), replace=False) |
|
ti = np.random.choice(len(self.pre_expert_data[bi])) |
|
tj = np.random.choice(len(self.pre_expert_data[bj])) |
|
|
|
|
|
|
|
            min_length = min(len(self.pre_expert_data[bi][ti]), len(self.pre_expert_data[bj][tj]))
            # T-REX snippet trick: the snippet taken from the higher-indexed (better-ranked) bin starts no
            # earlier than the snippet taken from the lower-indexed bin.
            if bi < bj:
|
ti_start = np.random.randint(min_length - rand_length[n] + 1) |
|
|
|
tj_start = np.random.randint(ti_start, len(self.pre_expert_data[bj][tj]) - rand_length[n] + 1) |
|
else: |
|
tj_start = np.random.randint(min_length - rand_length[n] + 1) |
|
|
|
ti_start = np.random.randint(tj_start, len(self.pre_expert_data[bi][ti]) - rand_length[n] + 1) |
|
|
|
traj_i = self.pre_expert_data[bi][ti][ti_start:ti_start + rand_length[n]:2] |
|
traj_j = self.pre_expert_data[bj][tj][tj_start:tj_start + rand_length[n]:2] |
|
|
|
max_traj_length = max(max_traj_length, len(traj_i), len(traj_j)) |
|
label = int(bi <= bj) |
|
self.training_obs.append((traj_i, traj_j)) |
|
self.training_labels.append(label) |
|
        self._logger.info("maximum traj length: {}".format(max_traj_length))
|
return self.training_obs, self.training_labels |
|
|
|
def _train(self): |
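        """
        Overview:
            Train the reward model with the pairwise ranking (Bradley-Terry style) loss used by T-REX: the \
            predicted returns of each trajectory pair are treated as logits and cross-entropy is applied \
            against the preference label, plus an optional L1 penalty on the absolute rewards.
        """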
|
|
|
device = self.device |
|
|
|
self._logger.info("device: {}".format(device)) |
|
training_inputs, training_outputs = self.training_obs, self.training_labels |
|
loss_criterion = nn.CrossEntropyLoss() |
|
|
|
cum_loss = 0.0 |
|
training_data = list(zip(training_inputs, training_outputs)) |
|
for epoch in range(self.cfg.reward_model.update_per_collect): |
|
np.random.shuffle(training_data) |
|
training_obs, training_labels = zip(*training_data) |
|
for i in range(len(training_labels)): |
|
|
|
|
|
traj_i, traj_j = training_obs[i] |
|
traj_i = np.array(traj_i) |
|
traj_j = np.array(traj_j) |
|
traj_i = torch.from_numpy(traj_i).float().to(device) |
|
traj_j = torch.from_numpy(traj_j).float().to(device) |
|
|
|
|
|
labels = torch.tensor([training_labels[i]]).to(device) |
|
|
|
|
|
outputs, abs_rewards = self.reward_model.forward(traj_i, traj_j) |
|
outputs = outputs.unsqueeze(0) |
|
loss = loss_criterion(outputs, labels) + self.l1_reg * abs_rewards |
|
self.opt.zero_grad() |
|
loss.backward() |
|
self.opt.step() |
|
|
|
|
|
item_loss = loss.item() |
|
cum_loss += item_loss |
|
if i % 100 == 99: |
|
self._logger.info("[epoch {}:{}] loss {}".format(epoch, i, cum_loss)) |
|
self._logger.info("abs_returns: {}".format(abs_rewards)) |
|
cum_loss = 0.0 |
|
self._logger.info("check pointing") |
|
if not os.path.exists(os.path.join(self.cfg.exp_name, 'ckpt_reward_model')): |
|
os.makedirs(os.path.join(self.cfg.exp_name, 'ckpt_reward_model')) |
|
torch.save(self.reward_model.state_dict(), os.path.join(self.cfg.exp_name, 'ckpt_reward_model/latest.pth.tar')) |
|
self._logger.info("finished training") |
|
|
|
def train(self): |
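        """
        Overview:
            Run ``_train`` and then log the predicted return of every demonstration against its ground-truth \
            return, together with the ranking accuracy on the training pairs.
        """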
|
self._train() |
|
|
|
sorted_returns = sorted(self.learning_returns, key=lambda s: s[0]) |
|
demonstrations = [ |
|
x for _, x in sorted(zip(self.learning_returns, self.pre_expert_data), key=lambda pair: pair[0][0]) |
|
] |
|
with torch.no_grad(): |
|
pred_returns = [self.predict_traj_return(self.reward_model, traj[0]) for traj in demonstrations] |
|
for i, p in enumerate(pred_returns): |
|
self._logger.info("{} {} {}".format(i, p, sorted_returns[i][0])) |
|
info = { |
|
"demo_length": [len(d[0]) for d in self.pre_expert_data], |
|
"min_snippet_length": self.min_snippet_length, |
|
"max_snippet_length": min(np.min([len(d[0]) for d in self.pre_expert_data]), self.max_snippet_length), |
|
"len_num_training_obs": len(self.training_obs), |
|
"lem_num_labels": len(self.training_labels), |
|
"accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels), |
|
} |
|
self._logger.info( |
|
"accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])) |
|
) |
|
|
|
def predict_traj_return(self, net, traj): |
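        """
        Overview:
            Predict the return of a single trajectory as the sum of the per-step rewards given by ``net``.
        """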
|
device = self.device |
|
|
|
|
|
with torch.no_grad(): |
|
rewards_from_obs = net.cum_return( |
|
torch.from_numpy(np.array(traj)).float().to(device), mode='batch' |
|
)[0].squeeze().tolist() |
|
|
|
|
|
return sum(rewards_from_obs) |
|
|
|
def calc_accuracy(self, reward_network, training_inputs, training_outputs): |
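        """
        Overview:
            Compute the pairwise ranking accuracy of ``reward_network`` on the given trajectory pairs and labels.
        """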
|
device = self.device |
|
loss_criterion = nn.CrossEntropyLoss() |
|
num_correct = 0. |
|
with torch.no_grad(): |
|
for i in range(len(training_inputs)): |
|
label = training_outputs[i] |
|
traj_i, traj_j = training_inputs[i] |
|
traj_i = np.array(traj_i) |
|
traj_j = np.array(traj_j) |
|
traj_i = torch.from_numpy(traj_i).float().to(device) |
|
traj_j = torch.from_numpy(traj_j).float().to(device) |
|
|
|
|
|
outputs, abs_return = reward_network.forward(traj_i, traj_j) |
|
_, pred_label = torch.max(outputs, 0) |
|
if pred_label.item() == label: |
|
num_correct += 1. |
|
return num_correct / len(training_inputs) |
|
|
|
def pred_data(self, data): |
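        """
        Overview:
            For each episode in ``data``, return both the real return (sum of recorded rewards) and the return \
            predicted by the reward model.
        """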
|
obs = [default_collate(data[i])['obs'] for i in range(len(data))] |
|
res = [torch.sum(default_collate(data[i])['reward']).item() for i in range(len(data))] |
|
pred_returns = [self.predict_traj_return(self.reward_model, obs[i]) for i in range(len(obs))] |
|
return {'real': res, 'pred': pred_returns} |
|
|
|
def estimate(self, data: list) -> List[Dict]: |
|
""" |
|
Overview: |
|
Estimate reward by rewriting the reward key in each row of the data. |
|
Arguments: |
|
- data (:obj:`list`): the list of data used for estimation, with at least \ |
|
``obs`` and ``action`` keys. |
|
        Returns:
            - train_data_augmented (:obj:`List[Dict]`): A deep copy of ``data`` whose ``reward`` fields are \
                replaced by the model's estimated rewards (the input list itself is left untouched).
|
""" |
|
|
|
|
|
train_data_augmented = self.reward_deepcopy(data) |
|
|
|
res = collect_states(train_data_augmented) |
|
res = torch.stack(res).to(self.device) |
|
        with torch.no_grad():
            # ``mode='batch'`` returns one reward per state rather than the summed return.
            rewards, abs_rewards = self.reward_model.cum_return(res, mode='batch')

        for item, rew in zip(train_data_augmented, rewards):
            item['reward'] = rew
|
|
|
return train_data_augmented |
|
|
|
def collect_data(self, data: list) -> None: |
|
""" |
|
Overview: |
|
            Collecting training data; for Trex this is a no-op kept for interface compatibility, since the \
            training data is loaded offline by ``load_expert_data``.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc)
|
""" |
|
pass |
|
|
|
def clear_data(self) -> None: |
|
""" |
|
Overview: |
|
Clearing training data. \ |
|
This is a side effect function which clears the data attribute in ``self`` |
|
""" |
|
self.training_obs.clear() |
|
self.training_labels.clear() |
|
|