|
import ding.config |
|
from .a2c import A2CAgent |
|
from .c51 import C51Agent |
|
from .ddpg import DDPGAgent |
|
from .dqn import DQNAgent |
|
from .pg import PGAgent |
|
from .ppof import PPOF |
|
from .ppo_offpolicy import PPOOffPolicyAgent |
|
from .sac import SACAgent |
|
from .sql import SQLAgent |
|
from .td3 import TD3Agent |
|
|
|
# Registry of the agents exposed by this package: public algorithm name -> agent class.
# The keys double as the attribute names looked up on ``ding.config.example``
# by the helper functions in this module.
supported_algo = {
    "A2C": A2CAgent,
    "C51": C51Agent,
    "DDPG": DDPGAgent,
    "DQN": DQNAgent,
    "PG": PGAgent,
    "PPOF": PPOF,
    "PPOOffPolicy": PPOOffPolicyAgent,
    "SAC": SACAgent,
    "SQL": SQLAgent,
    "TD3": TD3Agent,
}

# Algorithm names in registry (insertion) order.
supported_algo_list = list(supported_algo)
|
|
|
|
|
def env_supported(algo: str = None) -> list:
    """
    Return the list of env ids supported by di-engine.

    The ids are the keys of ``ding.config.example.<ALGO>.supported_env``.

    Arguments:
        - algo (:obj:`str`): Optional algorithm name, matched case-insensitively. \
            When omitted, the ids of all algorithms are merged.
    Returns:
        - envs (:obj:`list`): Env ids for ``algo``, or the de-duplicated union over \
            every algorithm (built from a set, so the order is not guaranteed).
    Raises:
        - ValueError: ``algo`` is given but is not one of the known algorithm names.
    """
    # Upper-cased lookup key -> attribute name on ``ding.config.example``.
    example_attr = {
        "A2C": "A2C",
        "C51": "C51",
        "DDPG": "DDPG",
        "DQN": "DQN",
        "PG": "PG",
        "PPOF": "PPOF",
        "PPOOFFPOLICY": "PPOOffPolicy",
        "SAC": "SAC",
        "SQL": "SQL",
        "TD3": "TD3",
    }

    if algo is None:
        # No filter requested: merge the env ids of every algorithm.
        merged = set()
        for attr_name in example_attr.values():
            merged.update(getattr(ding.config.example, attr_name).supported_env.keys())
        return list(merged)

    attr_name = example_attr.get(algo.upper())
    if attr_name is None:
        raise ValueError("The algo {} is not supported by di-engine.".format(algo))
    return list(getattr(ding.config.example, attr_name).supported_env.keys())
|
|
|
|
|
# Computed once at import time: the union of env ids over all algorithms.
# NOTE: this is a plain ``list`` (the return value of ``env_supported``), not a mapping.
supported_env = env_supported(algo=None)
|
|
|
|
|
def algo_supported(env_id: str = None) -> list:
    """
    Return the list of algos supported by di-engine.

    Arguments:
        - env_id (:obj:`str`): Optional env id, matched case-insensitively against the \
            keys of ``ding.config.example.<ALGO>.supported_env``. When omitted, every \
            registered algorithm name is returned.
    Returns:
        - algos (:obj:`list`): Algorithm names that list ``env_id`` as supported, or all \
            algorithm names when ``env_id`` is ``None``. A new list is returned on every call.
    Raises:
        - ValueError: ``env_id`` is given but no algorithm lists it as supported.
    """
    if env_id is None:
        # Return a copy: handing out the module-level list would let a caller's
        # in-place edit (append/remove/sort) silently change the registry.
        return list(supported_algo_list)

    # Attribute names on ``ding.config.example``; also the names reported to the caller.
    algo_names = ("A2C", "C51", "DDPG", "DQN", "PG", "PPOF", "PPOOffPolicy", "SAC", "SQL", "TD3")

    wanted = env_id.upper()
    algo = []
    for name in algo_names:
        env_ids = getattr(ding.config.example, name).supported_env.keys()
        # Compare case-insensitively on both sides.
        if any(wanted == item.upper() for item in env_ids):
            algo.append(name)

    if not algo:
        raise ValueError("The env {} is not supported by di-engine.".format(env_id))
    return algo
|
|
|
|
|
def is_supported(env_id: str = None, algo: str = None) -> bool:
    """
    Check if the env-algo pair is supported by di-engine.

    Both names are matched case-insensitively.

    Arguments:
        - env_id (:obj:`str`): Optional env id to check.
        - algo (:obj:`str`): Optional algorithm name to check.
    Returns:
        - supported (:obj:`bool`):
            - only ``env_id`` given: whether the env id is known to any algorithm;
            - only ``algo`` given: whether the algorithm name is registered;
            - both given: whether ``algo`` lists ``env_id`` as supported.
    Raises:
        - ValueError: neither ``env_id`` nor ``algo`` is given.
    """
    if env_id is None and algo is None:
        raise ValueError("Please specify the env or algo.")

    # Upper-case the registry names too: the list contains mixed-case entries
    # (e.g. "PPOOffPolicy") that an upper-cased query would otherwise never match.
    algo_known = algo is None or algo.upper() in {name.upper() for name in supported_algo_list}

    if env_id is None:
        return algo_known

    env_key = env_id.upper()
    # ``supported_env`` is a list of env ids (it has no ``.keys()``); iterate it directly.
    if env_key not in {item.upper() for item in supported_env}:
        return False
    if algo is None:
        return True
    if not algo_known:
        return False

    # ``env_supported`` returns ids in their original case, so normalise them
    # before comparing with the upper-cased query.
    return env_key in {item.upper() for item in env_supported(algo)}
|
|