File size: 4,979 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from abc import ABC, abstractmethod, abstractproperty
from easydict import EasyDict

from ding.utils import EasyTimer, import_module, get_task_uid, dist_init, dist_finalize, COMM_LEARNER_REGISTRY
from ding.policy import create_policy
from ding.worker.learner import create_learner


class BaseCommLearner(ABC):
    """
    Overview:
        Abstract baseclass for CommLearner. Concrete subclasses provide the three communication
        methods and the ``hooks4call`` property; this base class holds the shared config, the
        distributed rank/world-size information and the running flag.
    Interfaces:
        __init__, send_policy, get_data, send_learn_info, start, close
    Property:
        hooks4call
    """

    def __init__(self, cfg: 'EasyDict') -> None:  # noqa
        """
        Overview:
            Initialization method. Stores the config, requests a task uid, creates a timer and,
            when ``cfg.multi_gpu`` is truthy, initializes the distributed backend through ``dist_init``.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict. Must expose the ``multi_gpu`` attribute.
        """
        self._cfg = cfg
        self._learner_uid = get_task_uid()
        self._timer = EasyTimer()
        self._multi_gpu = cfg.multi_gpu
        if self._multi_gpu:
            # ``dist_init`` yields this process's rank and the total world size.
            self._rank, self._world_size = dist_init()
        else:
            # Single-process fallback: rank 0 in a world of size 1.
            self._rank, self._world_size = 0, 1
        # The comm learner counts as "ended" until ``start`` is called.
        self._end_flag = True

    @abstractmethod
    def send_policy(self, state_dict: dict) -> None:
        """
        Overview:
            Save learner's policy in corresponding path.
            Will be registered in base learner.
        Arguments:
            - state_dict (:obj:`dict`): State dict of the runtime policy.
        """
        raise NotImplementedError

    @abstractmethod
    def get_data(self, batch_size: int) -> list:
        """
        Overview:
            Get batched meta data from coordinator.
            Will be registered in base learner.
        Arguments:
            - batch_size (:obj:`int`): Batch size.
        Returns:
            - stepdata (:obj:`list`): A list of training data, each element is one trajectory.
        """
        raise NotImplementedError

    @abstractmethod
    def send_learn_info(self, learn_info: dict) -> None:
        """
        Overview:
            Send learn info to coordinator.
            Will be registered in base learner.
        Arguments:
            - learn_info (:obj:`dict`): Learn info in dict type.
        """
        raise NotImplementedError

    def start(self) -> None:
        """
        Overview:
            Start comm learner by clearing the end flag.
        """
        self._end_flag = False

    def close(self) -> None:
        """
        Overview:
            Close comm learner: set the end flag and, in multi-gpu mode, finalize the
            distributed backend through ``dist_finalize``.
        """
        self._end_flag = True
        if self._multi_gpu:
            dist_finalize()

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking ``property`` on top of
    # ``abstractmethod`` is the supported spelling and keeps ``hooks4call`` abstract.
    @property
    @abstractmethod
    def hooks4call(self) -> list:
        """
        Returns:
            - hooks (:obj:`list`): The hooks which comm learner has. Will be registered in learner as well.
        """
        raise NotImplementedError

    def _create_learner(self, task_info: dict) -> 'BaseLearner':  # noqa
        """
        Overview:
            Receive ``task_info`` passed from coordinator and create a learner.
        Arguments:
            - task_info (:obj:`dict`): Task info dict from coordinator. Should be like
                {"learner_cfg": xxx, "policy": xxx}. ``learner_cfg`` must contain ``exp_name``.
        Returns:
            - learner (:obj:`BaseLearner`): Created base learner.

        .. note::
            Three methods('get_data', 'send_policy', 'send_learn_info'), dataloader and policy are set.
            The reason why they are set here rather than base learner is that, they highly depend on the specific task.
            Only after task info is passed from coordinator to comm learner through learner slave, can they be
            clarified and initialized.
        """
        # Prepare learner config and instantiate a learner object.
        learner_cfg = EasyDict(task_info['learner_cfg'])
        learner = create_learner(learner_cfg, dist_info=[self._rank, self._world_size], exp_name=learner_cfg.exp_name)
        # Bind this comm learner's 3 communication methods onto the created learner;
        # they are necessary in the parallel setting.
        for item in ['get_data', 'send_policy', 'send_learn_info']:
            setattr(learner, item, getattr(self, item))
        # Set policy in created learner (only the 'learn' field is enabled).
        policy_cfg = EasyDict(task_info['policy'])
        learner.policy = create_policy(policy_cfg, enable_field=['learn']).learn_mode
        # The dataloader is set up last, after the methods and the policy are attached.
        learner.setup_dataloader()
        return learner


def create_comm_learner(cfg: EasyDict) -> BaseCommLearner:
    """
    Overview:
        Import the modules listed in ``cfg.import_names`` (an empty list when the key is absent),
        then ask ``COMM_LEARNER_REGISTRY`` to build the comm learner registered under ``cfg.type``.
        In other words, a derived comm learner must first register, then ``create_comm_learner``
        can be called to get the instance.
    Arguments:
        - cfg (:obj:`dict`): Learner config. Necessary key: ``type``; optional key: ``import_names``.
    Returns:
        - learner (:obj:`BaseCommLearner`): The comm learner built by the registry, with ``cfg``
            forwarded as the ``cfg`` keyword argument.

    .. note::
        Behaviour for an unregistered ``type`` is decided by ``COMM_LEARNER_REGISTRY.build``
        (presumably an error is raised; confirm against the registry implementation).
    """
    # Importing first gives the listed modules a chance to register their comm learners.
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)
    learner_type = cfg.type
    return COMM_LEARNER_REGISTRY.build(learner_type, cfg=cfg)