Spaces:
Sleeping
Sleeping
| from ding.utils import POLICY_REGISTRY | |
| from ding.rl_utils import get_epsilon_greedy_fn | |
| from .base_policy import CommandModePolicy | |
| from .dqn import DQNPolicy, DQNSTDIMPolicy | |
| from .mdqn import MDQNPolicy | |
| from .c51 import C51Policy | |
| from .qrdqn import QRDQNPolicy | |
| from .iqn import IQNPolicy | |
| from .fqf import FQFPolicy | |
| from .rainbow import RainbowDQNPolicy | |
| from .r2d2 import R2D2Policy | |
| from .r2d2_gtrxl import R2D2GTrXLPolicy | |
| from .r2d2_collect_traj import R2D2CollectTrajPolicy | |
| from .sqn import SQNPolicy | |
| from .ppo import PPOPolicy, PPOOffPolicy, PPOPGPolicy, PPOSTDIMPolicy | |
| from .offppo_collect_traj import OffPPOCollectTrajPolicy | |
| from .ppg import PPGPolicy, PPGOffPolicy | |
| from .pg import PGPolicy | |
| from .a2c import A2CPolicy | |
| from .impala import IMPALAPolicy | |
| from .ngu import NGUPolicy | |
| from .ddpg import DDPGPolicy | |
| from .td3 import TD3Policy | |
| from .td3_vae import TD3VAEPolicy | |
| from .td3_bc import TD3BCPolicy | |
| from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy | |
| from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy | |
| from .mbpolicy.dreamer import DREAMERPolicy | |
| from .qmix import QMIXPolicy | |
| from .wqmix import WQMIXPolicy | |
| from .collaq import CollaQPolicy | |
| from .coma import COMAPolicy | |
| from .atoc import ATOCPolicy | |
| from .acer import ACERPolicy | |
| from .qtran import QTRANPolicy | |
| from .sql import SQLPolicy | |
| from .bc import BehaviourCloningPolicy | |
| from .ibc import IBCPolicy | |
| from .dqfd import DQFDPolicy | |
| from .r2d3 import R2D3Policy | |
| from .d4pg import D4PGPolicy | |
| from .cql import CQLPolicy, DiscreteCQLPolicy | |
| from .dt import DTPolicy | |
| from .pdqn import PDQNPolicy | |
| from .madqn import MADQNPolicy | |
| from .bdq import BDQPolicy | |
| from .bcq import BCQPolicy | |
| from .edac import EDACPolicy | |
| from .prompt_pg import PromptPGPolicy | |
| from .plan_diffuser import PDPolicy | |
| from .happo import HAPPOPolicy | |
class EpsCommandModePolicy(CommandModePolicy):
    """
    Overview:
        Command-mode policy whose collect setting carries an exploration epsilon
        produced by a schedule built from the policy config.
    """

    def _init_command(self) -> None:
        """
        Overview:
            Build the epsilon schedule from ``self._cfg.other.eps`` (fields ``start``, ``end``, ``decay``,
            ``type``) and store it on ``self.epsilon_greedy``.
        """
        schedule_cfg = self._cfg.other.eps
        schedule_args = (schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type)
        self.epsilon_greedy = get_epsilon_greedy_fn(*schedule_args)

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Produce the collect-mode setting, which consists of the current epsilon.
        Arguments:
            - command_info (:obj:`dict`): Must contain the key ``'envstep'``; its value is passed to the schedule.
        Returns:
            - collect_setting (:obj:`dict`): A dict with the single key ``'eps'``.
        """
        # The schedule is driven by the environment step count, not by a learner iteration count.
        current_eps = self.epsilon_greedy(command_info['envstep'])
        return {'eps': current_eps}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Learn mode needs no extra setting; ``command_info`` is ignored and an empty dict is returned.
        """
        return dict()

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Eval mode needs no extra setting; ``command_info`` is ignored and an empty dict is returned.
        """
        return dict()
class DummyCommandModePolicy(CommandModePolicy):
    """
    Overview:
        Command-mode policy with no-op hooks: nothing is initialised and every
        setting getter returns an empty dict.
    """

    def _init_command(self) -> None:
        """
        Overview:
            Intentionally does nothing.
        """
        return None

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Ignore ``command_info`` and return an empty collect setting.
        """
        return dict()

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Ignore ``command_info`` and return an empty learn setting.
        """
        return dict()

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Ignore ``command_info`` and return an empty eval setting.
        """
        return dict()
class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``BDQPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``MDQNPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``DQNPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class DQNSTDIMCommandModePolicy(DQNSTDIMPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``DQNSTDIMPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``DQFDPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
    """
    Overview:
        ``C51Policy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class QRDQNCommandModePolicy(QRDQNPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``QRDQNPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class IQNCommandModePolicy(IQNPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``IQNPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class FQFCommandModePolicy(FQFPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``FQFPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class RainbowDQNCommandModePolicy(RainbowDQNPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``RainbowDQNPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class R2D2CommandModePolicy(R2D2Policy, EpsCommandModePolicy):
    """
    Overview:
        ``R2D2Policy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class R2D2GTrXLCommandModePolicy(R2D2GTrXLPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``R2D2GTrXLPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class R2D2CollectTrajCommandModePolicy(R2D2CollectTrajPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``R2D2CollectTrajPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class R2D3CommandModePolicy(R2D3Policy, EpsCommandModePolicy):
    """
    Overview:
        ``R2D3Policy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class SQNCommandModePolicy(SQNPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``SQNPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class SQLCommandModePolicy(SQLPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``SQLPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PPOPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``HAPPOPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PPOSTDIMPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class PPOPGCommandModePolicy(PPOPGPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PPOPGPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class PPOOffCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PPOOffPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class PPOOffCollectTrajCommandModePolicy(OffPPOCollectTrajPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``OffPPOCollectTrajPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class PGCommandModePolicy(PGPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PGPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class A2CCommandModePolicy(A2CPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``A2CPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class IMPALACommandModePolicy(IMPALAPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``IMPALAPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class PPGOffCommandModePolicy(PPGOffPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PPGOffPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PPGPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``MADQNPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):
    """
    Overview:
        ``DDPGPolicy`` with command-mode hooks. An epsilon schedule is only created and used
        when ``self._cfg.action_space`` equals ``'hybrid'``.
    """

    def _init_command(self) -> None:
        """
        Overview:
            For a hybrid action space, build the epsilon schedule from ``self._cfg.other.eps`` and store it
            on ``self.epsilon_greedy``. For any other action space nothing is done.
        """
        is_hybrid = self._cfg.action_space == 'hybrid'
        if not is_hybrid:
            return
        schedule_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Produce the collect-mode setting; it carries an epsilon only for a hybrid action space.
        Arguments:
            - command_info (:obj:`dict`): For a hybrid action space it must contain the key ``'envstep'``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': ...}`` for a hybrid action space, otherwise an empty dict.
        """
        is_hybrid = self._cfg.action_space == 'hybrid'
        if not is_hybrid:
            return {}
        # The schedule is driven by the environment step count.
        env_step = command_info['envstep']
        return {'eps': self.epsilon_greedy(env_step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Ignore ``command_info`` and return an empty learn setting.
        """
        return dict()

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Ignore ``command_info`` and return an empty eval setting.
        """
        return dict()
class TD3CommandModePolicy(TD3Policy, DummyCommandModePolicy):
    """
    Overview:
        ``TD3Policy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class TD3VAECommandModePolicy(TD3VAEPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``TD3VAEPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class TD3BCCommandModePolicy(TD3BCPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``TD3BCPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class SACCommandModePolicy(SACPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``SACPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``MBSACPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``STEVESACPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``DREAMERPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``CQLPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``DiscreteCQLPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class DTCommandModePolicy(DTPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``DTPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class QMIXCommandModePolicy(QMIXPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``QMIXPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class WQMIXCommandModePolicy(WQMIXPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``WQMIXPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class CollaQCommandModePolicy(CollaQPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``CollaQPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class COMACommandModePolicy(COMAPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``COMAPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class ATOCCommandModePolicy(ATOCPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``ATOCPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class ACERCommandModePolisy(ACERPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``ACERPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.

    .. note::
        The class name is misspelled ("Polisy"). It is kept unchanged so existing imports and the
        class ``__name__`` stay stable; prefer the correctly spelled alias ``ACERCommandModePolicy``.
    """


# Correctly spelled alias for the historical (misspelled) class name above; both names refer to the same class.
ACERCommandModePolicy = ACERCommandModePolisy
class QTRANCommandModePolicy(QTRANPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``QTRANPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class NGUCommandModePolicy(NGUPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``NGUPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class D4PGCommandModePolicy(D4PGPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``D4PGPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``PDQNPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy):
    """
    Overview:
        ``DiscreteSACPolicy`` combined with the command-mode hooks of ``EpsCommandModePolicy``.
    """
class SQILSACCommandModePolicy(SQILSACPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``SQILSACPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class IBCCommandModePolicy(IBCPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``IBCPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """
class BCQCommandModelPolicy(BCQPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``BCQPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.

    .. note::
        The name uses "CommandModel" where the sibling classes use "CommandMode". It is kept unchanged
        for backward compatibility; the alias ``BCQCommandModePolicy`` follows the common naming scheme.
    """


# Alias that matches the ``*CommandModePolicy`` naming used by the other classes in this module.
BCQCommandModePolicy = BCQCommandModelPolicy
class EDACCommandModelPolicy(EDACPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``EDACPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.

    .. note::
        The name uses "CommandModel" where the sibling classes use "CommandMode". It is kept unchanged
        for backward compatibility; the alias ``EDACCommandModePolicy`` follows the common naming scheme.
    """


# Alias that matches the ``*CommandModePolicy`` naming used by the other classes in this module.
EDACCommandModePolicy = EDACCommandModelPolicy
class PDCommandModelPolicy(PDPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PDPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.

    .. note::
        The name uses "CommandModel" where the sibling classes use "CommandMode". It is kept unchanged
        for backward compatibility; the alias ``PDCommandModePolicy`` follows the common naming scheme.
    """


# Alias that matches the ``*CommandModePolicy`` naming used by the other classes in this module.
PDCommandModePolicy = PDCommandModelPolicy
class BCCommandModePolicy(BehaviourCloningPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``BehaviourCloningPolicy`` with command-mode hooks that schedule either a noise sigma
        (when ``self._cfg.continuous`` is truthy) or an epsilon (otherwise) for collect mode.
    """

    def _init_command(self) -> None:
        """
        Overview:
            Build the decay schedule and store it on ``self.epsilon_greedy``. The schedule parameters
            (``start``, ``end``, ``decay``, ``type``) come from ``self._cfg.collect.noise_sigma`` in the
            continuous case and from ``self._cfg.other.eps`` otherwise.
        """
        # Only the config node differs between the two cases; the schedule factory is the same.
        schedule_cfg = self._cfg.collect.noise_sigma if self._cfg.continuous else self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        """
        Overview:
            Produce the collect-mode setting from the current value of the schedule.
        Arguments:
            - command_info (:obj:`dict`): Must contain ``'learner_step'`` in the continuous case and \
                ``'envstep'`` otherwise.
        Returns:
            - collect_setting (:obj:`dict`): ``{'sigma': ...}`` in the continuous case, ``{'eps': ...}`` otherwise.
        """
        if self._cfg.continuous:
            # Continuous: the value decays with the learner step and is reported as ``sigma``.
            setting_key, step_key = 'sigma', 'learner_step'
        else:
            # Otherwise: the value decays with the environment step and is reported as ``eps``.
            setting_key, step_key = 'eps', 'envstep'
        return {setting_key: self.epsilon_greedy(command_info[step_key])}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """
        Overview:
            Ignore ``command_info`` and return an empty learn setting.
        """
        return dict()

    def _get_setting_eval(self, command_info: dict) -> dict:
        """
        Overview:
            Ignore ``command_info`` and return an empty eval setting.
        """
        return dict()
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
    """
    Overview:
        ``PromptPGPolicy`` combined with the no-op command-mode hooks of ``DummyCommandModePolicy``.
    """