|
from typing import TYPE_CHECKING, Callable |
|
from easydict import EasyDict |
|
from ding.rl_utils import get_epsilon_greedy_fn |
|
from ding.framework import task |
|
|
|
if TYPE_CHECKING: |
|
from ding.framework import OnlineRLContext |
|
|
|
|
|
def eps_greedy_handler(cfg: EasyDict) -> Callable: |
|
""" |
|
Overview: |
|
The middleware that computes epsilon value according to the env_step. |
|
Arguments: |
|
- cfg (:obj:`EasyDict`): Config. |
|
""" |
|
if task.router.is_active and not task.has_role(task.role.COLLECTOR): |
|
return task.void() |
|
|
|
eps_cfg = cfg.policy.other.eps |
|
handle = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) |
|
|
|
def _eps_greedy(ctx: "OnlineRLContext"): |
|
""" |
|
Input of ctx: |
|
- env_step (:obj:`int`): The env steps count. |
|
Output of ctx: |
|
- collect_kwargs['eps'] (:obj:`float`): The eps conditioned on env_step and cfg. |
|
""" |
|
|
|
ctx.collect_kwargs['eps'] = handle(ctx.env_step) |
|
yield |
|
try: |
|
ctx.collect_kwargs.pop('eps') |
|
except: |
|
pass |
|
|
|
return _eps_greedy |
|
|
|
|
|
def eps_greedy_masker(): |
|
""" |
|
Overview: |
|
The middleware that returns masked epsilon value and stop generating \ |
|
actions by the e_greedy method. |
|
""" |
|
|
|
def _masker(ctx: "OnlineRLContext"): |
|
""" |
|
Output of ctx: |
|
- collect_kwargs['eps'] (:obj:`float`): The masked eps value, default to -1. |
|
""" |
|
|
|
ctx.collect_kwargs['eps'] = -1 |
|
|
|
return _masker |
|
|