Spaces:
Sleeping
Sleeping
from ding.utils import POLICY_REGISTRY | |
from ding.rl_utils import get_epsilon_greedy_fn | |
from .base_policy import CommandModePolicy | |
from .dqn import DQNPolicy, DQNSTDIMPolicy | |
from .mdqn import MDQNPolicy | |
from .c51 import C51Policy | |
from .qrdqn import QRDQNPolicy | |
from .iqn import IQNPolicy | |
from .fqf import FQFPolicy | |
from .rainbow import RainbowDQNPolicy | |
from .r2d2 import R2D2Policy | |
from .r2d2_gtrxl import R2D2GTrXLPolicy | |
from .r2d2_collect_traj import R2D2CollectTrajPolicy | |
from .sqn import SQNPolicy | |
from .ppo import PPOPolicy, PPOOffPolicy, PPOPGPolicy, PPOSTDIMPolicy | |
from .offppo_collect_traj import OffPPOCollectTrajPolicy | |
from .ppg import PPGPolicy, PPGOffPolicy | |
from .pg import PGPolicy | |
from .a2c import A2CPolicy | |
from .impala import IMPALAPolicy | |
from .ngu import NGUPolicy | |
from .ddpg import DDPGPolicy | |
from .td3 import TD3Policy | |
from .td3_vae import TD3VAEPolicy | |
from .td3_bc import TD3BCPolicy | |
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy | |
from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy | |
from .mbpolicy.dreamer import DREAMERPolicy | |
from .qmix import QMIXPolicy | |
from .wqmix import WQMIXPolicy | |
from .collaq import CollaQPolicy | |
from .coma import COMAPolicy | |
from .atoc import ATOCPolicy | |
from .acer import ACERPolicy | |
from .qtran import QTRANPolicy | |
from .sql import SQLPolicy | |
from .bc import BehaviourCloningPolicy | |
from .ibc import IBCPolicy | |
from .dqfd import DQFDPolicy | |
from .r2d3 import R2D3Policy | |
from .d4pg import D4PGPolicy | |
from .cql import CQLPolicy, DiscreteCQLPolicy | |
from .dt import DTPolicy | |
from .pdqn import PDQNPolicy | |
from .madqn import MADQNPolicy | |
from .bdq import BDQPolicy | |
from .bcq import BCQPolicy | |
from .edac import EDACPolicy | |
from .prompt_pg import PromptPGPolicy | |
from .plan_diffuser import PDPolicy | |
from .happo import HAPPOPolicy | |
class EpsCommandModePolicy(CommandModePolicy):
    r"""
    Overview:
        Command-mode mixin for epsilon-greedy exploration. At init it builds an epsilon schedule from
        ``self._cfg.other.eps``; in collect mode it reports the epsilon for the current env step.
        Learn and eval modes receive an empty settings dict.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method. Called by ``self.__init__``.
            Build the epsilon-greedy schedule from ``self._cfg.other.eps`` (fields ``start``, ``end``, ``decay``
            and ``type``) and store it as ``self.epsilon_greedy``.
        """
        eps_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information including eps.
        Arguments:
            - command_info (:obj:`dict`): Dict type, must contain at least the key ``'envstep'``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': value}``, where ``value`` is the schedule evaluated at \
                ``command_info['envstep']``.
        """
        # Epsilon decays with the number of collected environment steps, not with learner iterations.
        step = command_info['envstep']
        return {'eps': self.epsilon_greedy(step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        r"""
        Overview:
            Learn mode needs no command-driven settings; ``command_info`` is ignored.
        Returns:
            - learn_setting (:obj:`dict`): Always an empty dict.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        r"""
        Overview:
            Eval mode needs no command-driven settings; ``command_info`` is ignored.
        Returns:
            - eval_setting (:obj:`dict`): Always an empty dict.
        """
        return {}
class DummyCommandModePolicy(CommandModePolicy):
    r"""
    Overview:
        No-op command-mode mixin for policies whose behaviour is not driven by commands. Initialization does
        nothing, and collect, learn and eval modes all receive an empty settings dict.
    """

    def _init_command(self) -> None:
        # This mixin keeps no command state, so there is nothing to build.
        return None

    def _get_setting_collect(self, command_info: dict) -> dict:
        """Return no extra collect-mode settings; ``command_info`` is ignored."""
        return dict()

    def _get_setting_learn(self, command_info: dict) -> dict:
        """Return no extra learn-mode settings; ``command_info`` is ignored."""
        return dict()

    def _get_setting_eval(self, command_info: dict) -> dict:
        """Return no extra eval-mode settings; ``command_info`` is ignored."""
        return dict()
class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
    """``BDQPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
    """``MDQNPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
    """``DQNPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class DQNSTDIMCommandModePolicy(DQNSTDIMPolicy, EpsCommandModePolicy):
    """``DQNSTDIMPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
    """``DQFDPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
    """``C51Policy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class QRDQNCommandModePolicy(QRDQNPolicy, EpsCommandModePolicy):
    """``QRDQNPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class IQNCommandModePolicy(IQNPolicy, EpsCommandModePolicy):
    """``IQNPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class FQFCommandModePolicy(FQFPolicy, EpsCommandModePolicy):
    """``FQFPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class RainbowDQNCommandModePolicy(RainbowDQNPolicy, EpsCommandModePolicy):
    """``RainbowDQNPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class R2D2CommandModePolicy(R2D2Policy, EpsCommandModePolicy):
    """``R2D2Policy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class R2D2GTrXLCommandModePolicy(R2D2GTrXLPolicy, EpsCommandModePolicy):
    """``R2D2GTrXLPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class R2D2CollectTrajCommandModePolicy(R2D2CollectTrajPolicy, DummyCommandModePolicy):
    """``R2D2CollectTrajPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class R2D3CommandModePolicy(R2D3Policy, EpsCommandModePolicy):
    """``R2D3Policy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class SQNCommandModePolicy(SQNPolicy, DummyCommandModePolicy):
    """``SQNPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class SQLCommandModePolicy(SQLPolicy, EpsCommandModePolicy):
    """``SQLPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""
class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
    """``PPOPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
    """``HAPPOPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
    """``PPOSTDIMPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PPOPGCommandModePolicy(PPOPGPolicy, DummyCommandModePolicy):
    """``PPOPGPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PPOOffCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
    """``PPOOffPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PPOOffCollectTrajCommandModePolicy(OffPPOCollectTrajPolicy, DummyCommandModePolicy):
    """``OffPPOCollectTrajPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PGCommandModePolicy(PGPolicy, DummyCommandModePolicy):
    """``PGPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class A2CCommandModePolicy(A2CPolicy, DummyCommandModePolicy):
    """``A2CPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class IMPALACommandModePolicy(IMPALAPolicy, DummyCommandModePolicy):
    """``IMPALAPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PPGOffCommandModePolicy(PPGOffPolicy, DummyCommandModePolicy):
    """``PPGOffPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
    """``PPGPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
    """``MADQNPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""
class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):
    r"""
    Overview:
        Command-mode ``DDPGPolicy``. Only when ``self._cfg.action_space == 'hybrid'`` does it build an
        epsilon-greedy schedule (from ``self._cfg.other.eps``) and report the current epsilon in collect mode.
        For any other action space, and for learn/eval modes, the settings dict is empty.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method. Called by ``self.__init__``.
            For a hybrid action space, store the epsilon-greedy schedule built from ``self._cfg.other.eps`` as
            ``self.epsilon_greedy``; otherwise do nothing.
        """
        if self._cfg.action_space != 'hybrid':
            # Non-hybrid action spaces take no exploration schedule from the command.
            return
        schedule = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(schedule.start, schedule.end, schedule.decay, schedule.type)

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information; contains eps only for a hybrid action space.
        Arguments:
            - command_info (:obj:`dict`): Dict type. Read only for a hybrid action space, in which case it must \
                contain ``'envstep'``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': value}`` for a hybrid action space, otherwise ``{}``.
        """
        if self._cfg.action_space != 'hybrid':
            return {}
        # The epsilon schedule is driven by the collector's env step count.
        env_step = command_info['envstep']
        return {'eps': self.epsilon_greedy(env_step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """Return no extra learn-mode settings; ``command_info`` is ignored."""
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """Return no extra eval-mode settings; ``command_info`` is ignored."""
        return {}
class TD3CommandModePolicy(TD3Policy, DummyCommandModePolicy):
    """``TD3Policy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class TD3VAECommandModePolicy(TD3VAEPolicy, DummyCommandModePolicy):
    """``TD3VAEPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class TD3BCCommandModePolicy(TD3BCPolicy, DummyCommandModePolicy):
    """``TD3BCPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class SACCommandModePolicy(SACPolicy, DummyCommandModePolicy):
    """``SACPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
    """``MBSACPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
    """``STEVESACPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
    """``DREAMERPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
    """``CQLPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
    """``DiscreteCQLPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class DTCommandModePolicy(DTPolicy, DummyCommandModePolicy):
    """``DTPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""
class QMIXCommandModePolicy(QMIXPolicy, EpsCommandModePolicy):
    """``QMIXPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class WQMIXCommandModePolicy(WQMIXPolicy, EpsCommandModePolicy):
    """``WQMIXPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class CollaQCommandModePolicy(CollaQPolicy, EpsCommandModePolicy):
    """``CollaQPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class COMACommandModePolicy(COMAPolicy, EpsCommandModePolicy):
    """``COMAPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class ATOCCommandModePolicy(ATOCPolicy, DummyCommandModePolicy):
    """``ATOCPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class ACERCommandModePolicy(ACERPolicy, DummyCommandModePolicy):
    """``ACERPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


# Backward-compatible alias: this class was previously published under the misspelled name ``...Polisy``.
ACERCommandModePolisy = ACERCommandModePolicy


class QTRANCommandModePolicy(QTRANPolicy, EpsCommandModePolicy):
    """``QTRANPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class NGUCommandModePolicy(NGUPolicy, EpsCommandModePolicy):
    """``NGUPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""
class D4PGCommandModePolicy(D4PGPolicy, DummyCommandModePolicy):
    """``D4PGPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy):
    """``PDQNPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy):
    """``DiscreteSACPolicy`` combined with the epsilon-greedy command mixin ``EpsCommandModePolicy``."""


class SQILSACCommandModePolicy(SQILSACPolicy, DummyCommandModePolicy):
    """``SQILSACPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class IBCCommandModePolicy(IBCPolicy, DummyCommandModePolicy):
    """``IBCPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""
class BCQCommandModePolicy(BCQPolicy, DummyCommandModePolicy):
    """``BCQPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class EDACCommandModePolicy(EDACPolicy, DummyCommandModePolicy):
    """``EDACPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


class PDCommandModePolicy(PDPolicy, DummyCommandModePolicy):
    """``PDPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""


# Backward-compatible aliases: these classes were previously published as ``...CommandModelPolicy``
# ("Model" instead of "Mode"), unlike every other ``...CommandModePolicy`` class in this module.
BCQCommandModelPolicy = BCQCommandModePolicy
EDACCommandModelPolicy = EDACCommandModePolicy
PDCommandModelPolicy = PDCommandModePolicy
class BCCommandModePolicy(BehaviourCloningPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Command-mode ``BehaviourCloningPolicy``. When ``self._cfg.continuous`` is true, it schedules a collect
        noise sigma from ``self._cfg.collect.noise_sigma``, driven by the learner step. Otherwise it schedules an
        epsilon from ``self._cfg.other.eps``, driven by the env step.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method. Called by ``self.__init__``.
            Build the decay schedule and store it as ``self.epsilon_greedy``. When ``self._cfg.continuous`` is
            true the schedule comes from ``self._cfg.collect.noise_sigma``; otherwise it comes from
            ``self._cfg.other.eps``.
        """
        # Both branches build the same kind of schedule and differ only in the config source. The attribute keeps
        # the name ``epsilon_greedy`` even when it schedules sigma.
        if self._cfg.continuous:
            schedule_cfg = self._cfg.collect.noise_sigma
        else:
            schedule_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information: the noise sigma for continuous actions, or eps otherwise.
        Arguments:
            - command_info (:obj:`dict`): Dict type. Must contain ``'learner_step'`` when \
                ``self._cfg.continuous`` is true, otherwise ``'envstep'``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'sigma': value}`` for continuous actions, ``{'eps': value}`` \
                otherwise.
        """
        if self._cfg.continuous:
            # The noise sigma decays with learner iterations.
            step = command_info['learner_step']
            return {'sigma': self.epsilon_greedy(step)}
        # Epsilon decays with the number of collected environment steps.
        step = command_info['envstep']
        return {'eps': self.epsilon_greedy(step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        """Return no extra learn-mode settings; ``command_info`` is ignored."""
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        """Return no extra eval-mode settings; ``command_info`` is ignored."""
        return {}
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
    """``PromptPGPolicy`` combined with the no-op command mixin ``DummyCommandModePolicy``."""