|
from ding.utils import POLICY_REGISTRY |
|
from ding.rl_utils import get_epsilon_greedy_fn |
|
from .base_policy import CommandModePolicy |
|
|
|
from .dqn import DQNPolicy, DQNSTDIMPolicy |
|
from .mdqn import MDQNPolicy |
|
from .c51 import C51Policy |
|
from .qrdqn import QRDQNPolicy |
|
from .iqn import IQNPolicy |
|
from .fqf import FQFPolicy |
|
from .rainbow import RainbowDQNPolicy |
|
from .r2d2 import R2D2Policy |
|
from .r2d2_gtrxl import R2D2GTrXLPolicy |
|
from .r2d2_collect_traj import R2D2CollectTrajPolicy |
|
from .sqn import SQNPolicy |
|
from .ppo import PPOPolicy, PPOOffPolicy, PPOPGPolicy, PPOSTDIMPolicy |
|
from .offppo_collect_traj import OffPPOCollectTrajPolicy |
|
from .ppg import PPGPolicy, PPGOffPolicy |
|
from .pg import PGPolicy |
|
from .a2c import A2CPolicy |
|
from .impala import IMPALAPolicy |
|
from .ngu import NGUPolicy |
|
from .ddpg import DDPGPolicy |
|
from .td3 import TD3Policy |
|
from .td3_vae import TD3VAEPolicy |
|
from .td3_bc import TD3BCPolicy |
|
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy |
|
from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy |
|
from .mbpolicy.dreamer import DREAMERPolicy |
|
from .qmix import QMIXPolicy |
|
from .wqmix import WQMIXPolicy |
|
from .collaq import CollaQPolicy |
|
from .coma import COMAPolicy |
|
from .atoc import ATOCPolicy |
|
from .acer import ACERPolicy |
|
from .qtran import QTRANPolicy |
|
from .sql import SQLPolicy |
|
from .bc import BehaviourCloningPolicy |
|
from .ibc import IBCPolicy |
|
|
|
from .dqfd import DQFDPolicy |
|
from .r2d3 import R2D3Policy |
|
|
|
from .d4pg import D4PGPolicy |
|
from .cql import CQLPolicy, DiscreteCQLPolicy |
|
from .dt import DTPolicy |
|
from .pdqn import PDQNPolicy |
|
from .madqn import MADQNPolicy |
|
from .bdq import BDQPolicy |
|
from .bcq import BCQPolicy |
|
from .edac import EDACPolicy |
|
from .prompt_pg import PromptPGPolicy |
|
from .plan_diffuser import PDPolicy |
|
from .happo import HAPPOPolicy |
|
|
|
|
|
class EpsCommandModePolicy(CommandModePolicy):
    """
    Overview:
        Command-mode base that builds an epsilon schedule from the config and reports the current
        epsilon value to collect mode. Learn and eval modes receive no settings.
    """

    def _init_command(self) -> None:
        """
        Overview:
            Command mode init hook (the original docstring states it is called by ``self.__init__``).
            Reads ``self._cfg.other.eps`` and stores the resulting schedule as ``self.epsilon_greedy``.
        """
        schedule_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Evaluate the epsilon schedule at the step stored under ``command_info['envstep']``.
        Arguments:
            - command_info (:obj:`dict`): Must contain the key ``'envstep'``; a missing key raises ``KeyError``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': <value returned by self.epsilon_greedy>}``.
        """
        return {'eps': self.epsilon_greedy(command_info['envstep'])}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Learn mode takes no command-driven settings here; ``command_info`` is ignored.
        Returns:
            - learn_setting (:obj:`dict`): Always an empty dict.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Eval mode takes no command-driven settings here; ``command_info`` is ignored.
        Returns:
            - eval_setting (:obj:`dict`): Always an empty dict.
        """
        return {}
|
|
|
|
|
class DummyCommandModePolicy(CommandModePolicy):
    """
    Overview:
        Command-mode base whose hooks do nothing: initialisation is a no-op and every
        ``_get_setting_*`` hook returns an empty dict regardless of ``command_info``.
    """

    def _init_command(self) -> None:
        """No command-mode state is created."""
        return None

    def _get_setting_collect(self, command_info: dict) -> dict:
        """Return an empty collect setting; ``command_info`` is ignored."""
        return {}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """Return an empty learn setting; ``command_info`` is ignored."""
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """Return an empty eval setting; ``command_info`` is ignored."""
        return {}
|
|
|
|
|
@POLICY_REGISTRY.register('bdq_command')
class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'bdq_command'``: ``BDQPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('mdqn_command')
class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'mdqn_command'``: ``MDQNPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('dqn_command')
class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'dqn_command'``: ``DQNPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('dqn_stdim_command')
class DQNSTDIMCommandModePolicy(DQNSTDIMPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'dqn_stdim_command'``: ``DQNSTDIMPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('dqfd_command')
class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'dqfd_command'``: ``DQFDPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('c51_command')
class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
    """Command-mode entry ``'c51_command'``: ``C51Policy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('qrdqn_command')
class QRDQNCommandModePolicy(QRDQNPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'qrdqn_command'``: ``QRDQNPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('iqn_command')
class IQNCommandModePolicy(IQNPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'iqn_command'``: ``IQNPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('fqf_command')
class FQFCommandModePolicy(FQFPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'fqf_command'``: ``FQFPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('rainbow_command')
class RainbowDQNCommandModePolicy(RainbowDQNPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'rainbow_command'``: ``RainbowDQNPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('r2d2_command')
class R2D2CommandModePolicy(R2D2Policy, EpsCommandModePolicy):
    """Command-mode entry ``'r2d2_command'``: ``R2D2Policy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('r2d2_gtrxl_command')
class R2D2GTrXLCommandModePolicy(R2D2GTrXLPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'r2d2_gtrxl_command'``: ``R2D2GTrXLPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('r2d2_collect_traj_command')
class R2D2CollectTrajCommandModePolicy(R2D2CollectTrajPolicy, DummyCommandModePolicy):
    """
    Command-mode entry ``'r2d2_collect_traj_command'``: ``R2D2CollectTrajPolicy`` plus
    ``DummyCommandModePolicy``; adds no members.
    """
|
|
|
|
|
@POLICY_REGISTRY.register('r2d3_command')
class R2D3CommandModePolicy(R2D3Policy, EpsCommandModePolicy):
    """Command-mode entry ``'r2d3_command'``: ``R2D3Policy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('sqn_command')
class SQNCommandModePolicy(SQNPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'sqn_command'``: ``SQNPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('sql_command')
class SQLCommandModePolicy(SQLPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'sql_command'``: ``SQLPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ppo_command')
class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'ppo_command'``: ``PPOPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('happo_command')
class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'happo_command'``: ``HAPPOPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ppo_stdim_command')
class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'ppo_stdim_command'``: ``PPOSTDIMPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ppo_pg_command')
class PPOPGCommandModePolicy(PPOPGPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'ppo_pg_command'``: ``PPOPGPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ppo_offpolicy_command')
class PPOOffCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'ppo_offpolicy_command'``: ``PPOOffPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('offppo_collect_traj_command')
class PPOOffCollectTrajCommandModePolicy(OffPPOCollectTrajPolicy, DummyCommandModePolicy):
    """
    Command-mode entry ``'offppo_collect_traj_command'``: ``OffPPOCollectTrajPolicy`` plus
    ``DummyCommandModePolicy``; adds no members.
    """
|
|
|
|
|
@POLICY_REGISTRY.register('pg_command')
class PGCommandModePolicy(PGPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'pg_command'``: ``PGPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('a2c_command')
class A2CCommandModePolicy(A2CPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'a2c_command'``: ``A2CPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('impala_command')
class IMPALACommandModePolicy(IMPALAPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'impala_command'``: ``IMPALAPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ppg_offpolicy_command')
class PPGOffCommandModePolicy(PPGOffPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'ppg_offpolicy_command'``: ``PPGOffPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ppg_command')
class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'ppg_command'``: ``PPGPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('madqn_command')
class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'madqn_command'``: ``MADQNPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ddpg_command')
class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):
    """
    Overview:
        Command-mode entry ``'ddpg_command'`` for ``DDPGPolicy``. An epsilon schedule is created and
        reported to collect mode only when ``self._cfg.action_space == 'hybrid'``; otherwise every hook
        is a no-op.
    """

    def _init_command(self) -> None:
        """
        Overview:
            Command mode init hook (the original docstring states it is called by ``self.__init__``).
            For the hybrid action space, build the schedule from ``self._cfg.other.eps`` and store it
            as ``self.epsilon_greedy``; for any other action space nothing is created.
        """
        if self._cfg.action_space != 'hybrid':
            return
        schedule_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Report the current epsilon to collect mode when the action space is hybrid.
        Arguments:
            - command_info (:obj:`dict`): Must contain ``'envstep'`` in the hybrid case; ignored otherwise.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': <schedule value>}`` for the hybrid action space, \
                else an empty dict.
        """
        # ``self.epsilon_greedy`` only exists in the hybrid case, so bail out first.
        if self._cfg.action_space != 'hybrid':
            return {}
        return {'eps': self.epsilon_greedy(command_info['envstep'])}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """Return an empty learn setting; ``command_info`` is ignored."""
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """Return an empty eval setting; ``command_info`` is ignored."""
        return {}
|
|
|
|
|
@POLICY_REGISTRY.register('td3_command')
class TD3CommandModePolicy(TD3Policy, DummyCommandModePolicy):
    """Command-mode entry ``'td3_command'``: ``TD3Policy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('td3_vae_command')
class TD3VAECommandModePolicy(TD3VAEPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'td3_vae_command'``: ``TD3VAEPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('td3_bc_command')
class TD3BCCommandModePolicy(TD3BCPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'td3_bc_command'``: ``TD3BCPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('sac_command')
class SACCommandModePolicy(SACPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'sac_command'``: ``SACPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('mbsac_command')
class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'mbsac_command'``: ``MBSACPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('stevesac_command')
class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'stevesac_command'``: ``STEVESACPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('dreamer_command')
class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'dreamer_command'``: ``DREAMERPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('cql_command')
class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'cql_command'``: ``CQLPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('discrete_cql_command')
class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
    """
    Command-mode entry ``'discrete_cql_command'``: ``DiscreteCQLPolicy`` plus ``EpsCommandModePolicy``;
    adds no members.
    """
|
|
|
|
|
@POLICY_REGISTRY.register('dt_command')
class DTCommandModePolicy(DTPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'dt_command'``: ``DTPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('qmix_command')
class QMIXCommandModePolicy(QMIXPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'qmix_command'``: ``QMIXPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('wqmix_command')
class WQMIXCommandModePolicy(WQMIXPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'wqmix_command'``: ``WQMIXPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('collaq_command')
class CollaQCommandModePolicy(CollaQPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'collaq_command'``: ``CollaQPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('coma_command')
class COMACommandModePolicy(COMAPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'coma_command'``: ``COMAPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('atoc_command')
class ATOCCommandModePolicy(ATOCPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'atoc_command'``: ``ATOCPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('acer_command')
class ACERCommandModePolicy(ACERPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'acer_command'``: ``ACERPolicy`` plus ``DummyCommandModePolicy``; adds no members."""


# Backward-compatible alias: the class was originally defined under this misspelled name
# ("Polisy"), so existing imports of the old spelling must keep resolving to the same class.
ACERCommandModePolisy = ACERCommandModePolicy
|
|
|
|
|
@POLICY_REGISTRY.register('qtran_command')
class QTRANCommandModePolicy(QTRANPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'qtran_command'``: ``QTRANPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ngu_command')
class NGUCommandModePolicy(NGUPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'ngu_command'``: ``NGUPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('d4pg_command')
class D4PGCommandModePolicy(D4PGPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'d4pg_command'``: ``D4PGPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('pdqn_command')
class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy):
    """Command-mode entry ``'pdqn_command'``: ``PDQNPolicy`` plus ``EpsCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('discrete_sac_command')
class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy):
    """
    Command-mode entry ``'discrete_sac_command'``: ``DiscreteSACPolicy`` plus ``EpsCommandModePolicy``;
    adds no members.
    """
|
|
|
|
|
@POLICY_REGISTRY.register('sqil_sac_command')
class SQILSACCommandModePolicy(SQILSACPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'sqil_sac_command'``: ``SQILSACPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('ibc_command')
class IBCCommandModePolicy(IBCPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'ibc_command'``: ``IBCPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|
|
|
|
@POLICY_REGISTRY.register('bcq_command')
class BCQCommandModePolicy(BCQPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'bcq_command'``: ``BCQPolicy`` plus ``DummyCommandModePolicy``; adds no members."""


# Backward-compatible alias: the class was originally defined with "CommandModel" in its name,
# unlike the ``*CommandModePolicy`` siblings; keep the old spelling importable.
BCQCommandModelPolicy = BCQCommandModePolicy
|
|
|
|
|
@POLICY_REGISTRY.register('edac_command')
class EDACCommandModePolicy(EDACPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'edac_command'``: ``EDACPolicy`` plus ``DummyCommandModePolicy``; adds no members."""


# Backward-compatible alias: the class was originally defined with "CommandModel" in its name,
# unlike the ``*CommandModePolicy`` siblings; keep the old spelling importable.
EDACCommandModelPolicy = EDACCommandModePolicy
|
|
|
|
|
@POLICY_REGISTRY.register('pd_command')
class PDCommandModePolicy(PDPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'pd_command'``: ``PDPolicy`` plus ``DummyCommandModePolicy``; adds no members."""


# Backward-compatible alias: the class was originally defined with "CommandModel" in its name,
# unlike the ``*CommandModePolicy`` siblings; keep the old spelling importable.
PDCommandModelPolicy = PDCommandModePolicy
|
|
|
|
|
@POLICY_REGISTRY.register('bc_command')
class BCCommandModePolicy(BehaviourCloningPolicy, DummyCommandModePolicy):
    """
    Overview:
        Command-mode entry ``'bc_command'`` for ``BehaviourCloningPolicy``. Although
        ``DummyCommandModePolicy`` is a base, all four command hooks are overridden here: a decaying
        schedule is built at init and reported to collect mode as ``'sigma'`` (continuous) or
        ``'eps'`` (otherwise).
    """

    def _init_command(self) -> None:
        """
        Overview:
            Command mode init hook (the original docstring states it is called by ``self.__init__``).
            Build ``self.epsilon_greedy`` from ``self._cfg.collect.noise_sigma`` when
            ``self._cfg.continuous`` is truthy, otherwise from ``self._cfg.other.eps``.
        """
        schedule_cfg = self._cfg.collect.noise_sigma if self._cfg.continuous else self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Evaluate the schedule for collect mode.
        Arguments:
            - command_info (:obj:`dict`): Must contain ``'learner_step'`` when ``self._cfg.continuous`` is \
                truthy, otherwise ``'envstep'``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'sigma': value}`` in the continuous case, else ``{'eps': value}``.
        """
        if self._cfg.continuous:
            # The continuous branch indexes the schedule by learner step, not env step.
            return {'sigma': self.epsilon_greedy(command_info['learner_step'])}
        return {'eps': self.epsilon_greedy(command_info['envstep'])}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """Return an empty learn setting; ``command_info`` is ignored."""
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """Return an empty eval setting; ``command_info`` is ignored."""
        return {}
|
|
|
|
|
@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
    """Command-mode entry ``'prompt_pg_command'``: ``PromptPGPolicy`` plus ``DummyCommandModePolicy``; adds no members."""
|
|