from typing import TYPE_CHECKING, Union, Callable, Optional
from ditk import logging
import numpy as np
import torch
from ding.utils import broadcast
from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def termination_checker(max_env_step: Optional[int] = None, max_train_iter: Optional[int] = None) -> Callable:
    """
    Overview:
        Middleware that sets ``task.finish`` once ``ctx.env_step`` exceeds ``max_env_step`` or
        ``ctx.train_iter`` exceeds ``max_train_iter``. A ``None`` limit leaves that dimension unbounded.
    """
    if max_env_step is None:
        max_env_step = np.inf
    if max_train_iter is None:
        max_train_iter = np.inf

    def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        # Use ">" rather than ">=" so the final step/iteration is still logged before termination.
        assert hasattr(ctx, "env_step") or hasattr(ctx, "train_iter"), "Context must have env_step or train_iter"
        if hasattr(ctx, "env_step") and ctx.env_step > max_env_step:
            task.finish = True
            logging.info('Exceeded maximum number of env_step({}), the program will terminate'.format(ctx.env_step))
        elif hasattr(ctx, "train_iter") and ctx.train_iter > max_train_iter:
            task.finish = True
            logging.info('Exceeded maximum number of train_iter({}), the program will terminate'.format(ctx.train_iter))

    return _check
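
# Usage sketch (illustrative, with hypothetical config values): the checker is
# registered like any other middleware inside a task pipeline, e.g.
#
#     from ding.framework import task, OnlineRLContext
#
#     with task.start(ctx=OnlineRLContext()):
#         ...  # collector / trainer middleware
#         task.use(termination_checker(max_env_step=int(3e6)))
#         task.run()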


def ddp_termination_checker(
        max_env_step: Optional[int] = None, max_train_iter: Optional[int] = None, rank: int = 0
) -> Callable:
    """
    Overview:
        DDP variant of ``termination_checker``. Only rank 0 evaluates the limits; its decision is
        broadcast to every worker so that all processes finish together.
    """
    if rank == 0:
        # Only rank 0 compares against the limits, so the defaults are only needed there.
        if max_env_step is None:
            max_env_step = np.inf
        if max_train_iter is None:
            max_train_iter = np.inf

    def _check(ctx):
        if rank == 0:
            if ctx.env_step > max_env_step:
                finish = torch.ones(1).long().cuda()
                logging.info('Exceeded maximum number of env_step({}), the program will terminate'.format(ctx.env_step))
            elif ctx.train_iter > max_train_iter:
                finish = torch.ones(1).long().cuda()
                logging.info('Exceeded maximum number of train_iter({}), the program will terminate'.format(ctx.train_iter))
            else:
                # Preserve a finish flag that may have been set elsewhere in the pipeline.
                finish = torch.LongTensor([task.finish]).cuda()
        else:
            # Non-zero ranks send a placeholder; it is overwritten by the broadcast below.
            finish = torch.zeros(1).long().cuda()
        # Broadcast rank 0's decision to all DDP workers.
        broadcast(finish, 0)
        task.finish = finish.cpu().bool().item()

    return _check
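
# DDP usage sketch (illustrative): each process builds the checker with its own
# rank; `get_rank` is assumed to be available from ding.utils, and the step
# limit is a hypothetical value.
#
#     from ding.utils import get_rank
#
#     task.use(ddp_termination_checker(max_env_step=int(3e6), rank=get_rank()))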