from typing import Dict |
import os |
import torch |
import torch.nn as nn |
import numpy as np |
import gym |
from gym import spaces |
from ditk import logging |
from ding.envs import DingEnvWrapper, EvalEpisodeReturnWrapper, \ |
BaseEnvManagerV2 |
from ding.config import compile_config |
from ding.policy import PPOPolicy |
from ding.utils import set_pkg_seed |
from ding.model import VAC |
from ding.framework import task, ding_init |
from ding.framework.context import OnlineRLContext |
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ |
gae_estimator, online_logger |
from easydict import EasyDict |
# Main experiment configuration for PPO on ``MyEnv``.
# ``main_config`` and ``my_env_ppo_config`` are two names for the same EasyDict.
main_config = my_env_ppo_config = EasyDict(
    {
        'exp_name': 'my_env_ppo_seed0',
        'env': {
            'collector_env_num': 4,
            'evaluator_env_num': 4,
            'n_evaluator_episode': 4,
            'stop_value': 195,
        },
        'policy': {
            'cuda': True,
            'action_space': 'discrete',
            'model': {
                # No plain obs shape is given: the observation is a nested dict
                # that is handled by the custom ``Encoder`` passed to ``VAC`` in ``main``.
                'obs_shape': None,
                'action_shape': 2,
                'action_space': 'discrete',
                # NOTE(review): 138 equals the width of the ``Encoder`` output when
                # feature_dim=10 (32 + 10 + 64 + 32); presumably the heads must match it.
                'critic_head_hidden_size': 138,
                'actor_head_hidden_size': 138,
            },
            'learn': {
                'epoch_per_collect': 2,
                'batch_size': 64,
                'learning_rate': 0.001,
                'value_weight': 0.5,
                'entropy_weight': 0.01,
                'clip_ratio': 0.2,
                'learner': {'hook': {'save_ckpt_after_iter': 100}},
            },
            'collect': {
                'n_sample': 256,
                'unroll_len': 1,
                'discount_factor': 0.9,
                'gae_lambda': 0.95,
                'collector': {'transform_obs': True},
            },
            'eval': {'evaluator': {'eval_freq': 100}},
        },
    }
)

# Component selection: which env manager and policy types to create.
# ``create_config`` and ``my_env_ppo_create_config`` are two names for the same EasyDict.
create_config = my_env_ppo_create_config = EasyDict(
    {
        'env_manager': {'type': 'base'},
        'policy': {'type': 'ppo'},
    }
)
class MyEnv(gym.Env):
    """Toy gym environment with a nested dict observation and purely random dynamics.

    Observations are sampled from ``observation_space`` on every ``reset`` and
    ``step``; the action passed to ``step`` is ignored.

    Observation layout:
        - ``key_0``: dict with ``k1`` (shape ``(1,)``, float32, >= 0) and
          ``k2`` (shape ``(1,)``, float32, in [-1, 1]).
        - ``key_1``: float32 box of shape ``(seq_len, feature_dim)``.
        - ``key_2``: uint8 box of shape ``image_size`` with values in [0, 255].
        - ``key_3``: float32 box of shape ``(2,)`` with upper bounds ``[inf, 3]``.
    """

    def __init__(self, seq_len=5, feature_dim=10, image_size=(10, 10, 3)):
        super().__init__()

        # Binary discrete action.
        self.action_space = spaces.Discrete(2)

        # Build each sub-space separately, then assemble the top-level dict space.
        nested_space = spaces.Dict(
            {
                'k1': spaces.Box(low=0, high=np.inf, shape=(1, ), dtype=np.float32),
                'k2': spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32),
            }
        )
        sequence_space = spaces.Box(low=-np.inf, high=np.inf, shape=(seq_len, feature_dim), dtype=np.float32)
        image_space = spaces.Box(low=0, high=255, shape=image_size, dtype=np.uint8)
        vector_space = spaces.Box(low=0, high=np.array([np.inf, 3]), shape=(2, ), dtype=np.float32)

        self.observation_space = spaces.Dict(
            {
                'key_0': nested_space,
                'key_1': sequence_space,
                'key_2': image_space,
                'key_3': vector_space,
            }
        )

    def reset(self):
        """Return a freshly sampled observation as the initial observation."""
        return self.observation_space.sample()

    def step(self, action):
        """Advance one step; ``action`` is not used.

        Returns:
            A ``(obs, reward, done, info)`` tuple where ``obs`` is a random sample
            of the observation space, ``reward`` is drawn uniformly from [0, 1),
            ``done`` is True when a second uniform draw exceeds 0.7, and ``info``
            is an empty dict.
        """
        reward = np.random.uniform(low=0.0, high=1.0)
        # Episode terminates when an independent uniform draw exceeds 0.7.
        done = bool(np.random.uniform(low=0.0, high=1.0) > 0.7)
        next_obs = self.observation_space.sample()
        return next_obs, reward, done, {}
def ding_env_maker():
    """Build one ``MyEnv`` instance wrapped in ``DingEnvWrapper``.

    The wrapper config requests a single env wrapper, ``EvalEpisodeReturnWrapper``.
    """
    raw_env = MyEnv()
    wrapper_factories = [
        lambda env: EvalEpisodeReturnWrapper(env),
    ]
    return DingEnvWrapper(raw_env, cfg={'env_wrapper': wrapper_factories})
class Encoder(nn.Module):
    """Encode the nested dict observation into one flat feature vector per sample.

    Each observation entry goes through its own branch and the branch outputs are
    concatenated along the feature dimension:

        - ``key_0`` (dict of ``k1``/``k2``): two small MLPs, fused to 32 features.
        - ``key_1`` (sequence): one transformer encoder layer with a prepended
          class token; the class-token output (``feature_dim`` features) is used.
        - ``key_2`` (image, channels-last): two 3x3 convolutions followed by a
          linear layer, giving 64 features.
        - ``key_3`` (vector of 2 values): an MLP giving 32 features.

    The output width is therefore ``32 + feature_dim + 64 + 32``
    (138 for ``feature_dim=10``).

    Args:
        feature_dim: Size of the last dimension of ``key_1``; also the transformer
            ``d_model``. The transformer uses 2 attention heads.
        image_size: ``(height, width, channels)`` of the ``key_2`` image. Defaults
            to ``(10, 10, 3)``, which reproduces the previously hard-coded layer sizes.

    Raises:
        ValueError: If ``image_size`` does not have exactly three entries.
    """

    def __init__(self, feature_dim: int, image_size: tuple = (10, 10, 3)):
        super(Encoder, self).__init__()
        if len(image_size) != 3:
            raise ValueError(
                "image_size must be (height, width, channels), got {!r}".format(image_size)
            )
        height, width, channels = image_size

        # Branch for the nested dict ``key_0``: one MLP per scalar entry, then a fusion layer.
        self.fc_net_1_k1 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
        self.fc_net_1_k2 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
        self.fc_net_1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())

        # Branch for the sequence ``key_1``.
        # Implementation of transformer_encoder refers to Vision Transformer (ViT) code:
        #     https://arxiv.org/abs/2010.11929
        #     https://pytorch.org/vision/main/_modules/torchvision/models/vision_transformer.html
        self.class_token = nn.Parameter(torch.zeros(1, 1, feature_dim))
        # NOTE(review): ``encoder_layer`` stays an attribute so the module's parameter
        # names are unchanged; confirm whether it is still needed beyond construction.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=2, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)

        # Branch for the image ``key_2``. Kernel 3 with padding 1 keeps the spatial
        # size, so the flattened conv output has 32 * height * width features.
        self.conv_net = nn.Sequential(
            nn.Conv2d(channels, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU()
        )
        self.conv_fc_net = nn.Sequential(nn.Flatten(), nn.Linear(32 * height * width, 64), nn.ReLU())

        # Branch for the plain vector ``key_3``.
        self.fc_net_2 = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 32), nn.ReLU(), nn.Flatten())

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode a batch of observations.

        Args:
            inputs: Dict with the keys ``key_0`` .. ``key_3``. Given the layer sizes
                used here, the expected shapes are: ``key_0['k1']`` and
                ``key_0['k2']`` of shape ``(B,)``, ``key_1`` of shape
                ``(B, seq_len, feature_dim)``, ``key_2`` of shape ``(B, H, W, C)``
                and ``key_3`` of shape ``(B, 2)``. ``key_2`` is fed to the
                convolution without a dtype cast, so it must already be floating point.

        Returns:
            Tensor of shape ``(B, 32 + feature_dim + 64 + 32)``.
        """
        dict_input = inputs['key_0']
        transformer_input = inputs['key_1']
        conv_input = inputs['key_2']
        fc_input = inputs['key_3']
        B = fc_input.shape[0]

        # (B,) -> (B, 1) -> (B, 8) per entry; concatenated to (B, 16) and fused to (B, 32).
        k1_feature = self.fc_net_1_k1(dict_input['k1'].unsqueeze(-1))
        k2_feature = self.fc_net_1_k2(dict_input['k2'].unsqueeze(-1))
        dict_output = self.fc_net_1(torch.cat([k1_feature, k2_feature], dim=1))

        # Prepend the learned class token and keep only its encoded output: (B, feature_dim).
        batch_class_token = self.class_token.expand(B, -1, -1)
        transformer_output = self.transformer_encoder(torch.cat([batch_class_token, transformer_input], dim=1))
        transformer_output = transformer_output[:, 0]

        # Channels-last (B, H, W, C) -> channels-first (B, C, H, W) for Conv2d; result is (B, 64).
        conv_output = self.conv_fc_net(self.conv_net(conv_input.permute(0, 3, 1, 2)))

        # (B, 2) -> (B, 32).
        fc_output = self.fc_net_2(fc_input)

        encoded_output = torch.cat([dict_output, transformer_output, conv_output, fc_output], dim=1)
        return encoded_output
def main():
    """Run the PPO training pipeline on ``MyEnv``.

    Compiles the config, creates collector and evaluator env managers, builds a
    ``VAC`` model around the custom ``Encoder``, and registers the evaluation,
    collection, GAE, training, checkpoint and logging middleware before running
    the task.
    """
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)

    def build_env_manager(env_count):
        # Every slot uses the same env factory; the manager calls it once per slot.
        return BaseEnvManagerV2(env_fn=[ding_env_maker] * env_count, cfg=cfg.env.manager)

    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = build_env_manager(cfg.env.collector_env_num)
        evaluator_env = build_env_manager(cfg.env.evaluator_env_num)

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        # feature_dim=10 matches the default ``feature_dim`` of ``MyEnv``.
        obs_encoder = Encoder(feature_dim=10)
        actor_critic = VAC(encoder=obs_encoder, **cfg.policy.model)
        policy = PPOPolicy(cfg.policy, model=actor_critic)

        # Middleware is registered in the order: evaluate, collect, estimate
        # advantages, train, save checkpoints, log.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(gae_estimator(cfg, policy.collect_mode))
        task.use(multistep_trainer(policy.learn_mode, log_freq=50))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(online_logger(train_show_freq=3))
        task.run()


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()