File size: 2,230 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from typing import TYPE_CHECKING, Union, Callable, Optional
from ditk import logging
import numpy as np
import torch
from ding.utils import broadcast
from ding.framework import task
if TYPE_CHECKING:
from ding.framework import OnlineRLContext, OfflineRLContext
def termination_checker(max_env_step: Optional[int] = None, max_train_iter: Optional[int] = None) -> Callable:
    """Build a middleware that flags the task as finished once a budget is exceeded.

    Args:
        max_env_step: Upper bound for ``ctx.env_step``. ``None`` means no bound.
        max_train_iter: Upper bound for ``ctx.train_iter``. ``None`` means no bound.

    Returns:
        A callable taking the context. When a counter is strictly greater than
        its bound it sets ``task.finish = True`` and logs a message; otherwise
        it does nothing. ``env_step`` is examined first, and ``train_iter`` is
        only examined when ``env_step`` did not trigger.

    Raises:
        AssertionError: Raised by the returned callable when the context has
            neither an ``env_step`` nor a ``train_iter`` attribute.
    """
    # A missing budget is represented by infinity so that "counter > bound" is never true.
    env_step_bound = np.inf if max_env_step is None else max_env_step
    train_iter_bound = np.inf if max_train_iter is None else max_train_iter

    def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        assert any(hasattr(ctx, counter) for counter in ("env_step", "train_iter")), \
            "Context must have env_step or train_iter"

        # Strict ">" (rather than ">=") is deliberate: it is better when taking
        # the logger result into consideration.
        if hasattr(ctx, "env_step") and ctx.env_step > env_step_bound:
            task.finish = True
            logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step))
            return

        if hasattr(ctx, "train_iter") and ctx.train_iter > train_iter_bound:
            task.finish = True
            logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter))

    return _check
def ddp_termination_checker(
        max_env_step: Optional[int] = None, max_train_iter: Optional[int] = None, rank: int = 0
) -> Callable:
    """Build a termination middleware whose decision is taken on rank 0 and broadcast.

    Rank 0 compares the context counters against the limits; every other rank
    contributes a zero placeholder. The flag is then passed through
    ``broadcast(finish, 0)`` and written to ``task.finish`` on the calling rank.

    Args:
        max_env_step: Upper bound for ``ctx.env_step``. ``None`` means no bound.
        max_train_iter: Upper bound for ``ctx.train_iter``. ``None`` means no bound.
        rank: Rank of the calling process. Only ``rank == 0`` evaluates the limits.

    Returns:
        A callable taking the context. It needs a CUDA device, because the flag
        tensor is created with ``.cuda()``.
    """
    # Only rank 0 reads the limits. A missing limit becomes infinity so the
    # strict ">" comparison can never fire for it.
    if rank == 0:
        if max_env_step is None:
            max_env_step = np.inf
        if max_train_iter is None:
            max_train_iter = np.inf

    def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        if rank == 0:
            # Guard each counter with hasattr, exactly like termination_checker:
            # a context that lacks one of them must not raise here, otherwise
            # rank 0 would never reach the broadcast below.
            # NOTE(review): broadcast is presumably a collective call that the
            # other ranks block on - confirm against ding.utils.
            if hasattr(ctx, "env_step") and ctx.env_step > max_env_step:
                finish = torch.ones(1).long().cuda()
                logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step))
            elif hasattr(ctx, "train_iter") and ctx.train_iter > max_train_iter:
                finish = torch.ones(1).long().cuda()
                logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter))
            else:
                # No limit exceeded: forward whatever finish state the task already has.
                finish = torch.LongTensor([task.finish]).cuda()
        else:
            # Placeholder; presumably overwritten in place by the broadcast from rank 0.
            finish = torch.zeros(1).long().cuda()
        # broadcast finish result to other DDP workers
        broadcast(finish, 0)
        task.finish = finish.cpu().bool().item()

    return _check
|