File size: 3,684 Bytes
80ebcb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

import torch

from ..trackers import TrackerType, initialize_trackers


class BaseParallelBackend:
    r"""
    Base class that contains properties and methods that should be implemented by different parallel backends.
    """

    def apply_ddp(self, *args, **kwargs) -> torch.nn.Module:
        raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.")

    def prepare_dataset(self, *args, **kwargs) -> Any:
        raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.")

    def prepare_dataloader(self, *args, **kwargs) -> Any:
        raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.")

    def prepare_optimizer(self, *args, **kwargs) -> Any:
        raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.")

    def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
        raise NotImplementedError("Method `get_mesh` must be implemented by subclass.")

    def initialize_trackers(
        self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
    ) -> TrackerType:
        self.tracker = None
        if self.is_main_process:
            self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir)

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        if self.is_main_process:
            self.tracker.log(metrics, step)

    def wait_for_everyone(self):
        raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.")

    @contextmanager
    def main_process_first(self):
        raise NotImplementedError("Method `main_process_first` must be implemented by subclass.")

    def destroy(self):
        raise NotImplementedError("Method `destroy` must be implemented by subclass.")

    @property
    def world_size(self):
        raise NotImplementedError("Method `world_size` must be implemented by subclass.")

    @property
    def rank(self):
        raise NotImplementedError("Method `rank` must be implemented by subclass.")

    @property
    def local_rank(self):
        raise NotImplementedError("Method `local_rank` must be implemented by subclass.")

    @property
    def is_main_process(self):
        raise NotImplementedError("Method `is_main_process` must be implemented by subclass.")

    @property
    def is_local_main_process(self):
        raise NotImplementedError("Method `is_local_main_process` must be implemented by subclass.")

    @property
    def device(self):
        raise NotImplementedError("Method `device` must be implemented by subclass.")

    @property
    def pipeline_parallel_enabled(self):
        raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.")

    @property
    def data_parallel_enabled(self):
        raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.")

    @property
    def data_replication_enabled(self):
        raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.")

    @property
    def data_sharding_enabled(self):
        raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.")

    @property
    def context_parallel_enabled(self):
        raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.")

    @property
    def tensor_parallel_enabled(self):
        raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.")