assert torch!=1.12.0 for DDP training (#8621)
* assert torch!=1.12.0 for DDP training
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- requirements.txt +2 -2
- train.py +5 -9
- utils/torch_utils.py +17 -1
requirements.txt
@@ -9,8 +9,8 @@ Pillow>=7.1.2
 PyYAML>=5.3.1
 requests>=2.23.0
 scipy>=1.4.1
-torch>=1.7.0
-torchvision>=0.8.1
+torch>=1.7.0,!=1.12.0  # https://github.com/ultralytics/yolov5/issues/8395
+torchvision>=0.8.1,!=0.13.0  # https://github.com/ultralytics/yolov5/issues/8395
 tqdm>=4.64.0
 protobuf<4.21.3  # https://github.com/ultralytics/yolov5/issues/8012
 
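The `!=` exclusion markers keep pip from resolving to the affected releases while still accepting any other version in range. A minimal sketch of how the new specifiers behave, using the third-party `packaging` library that pip's resolver builds on; the `torch_spec` variable name is illustrative only, not part of the diff:

from packaging.specifiers import SpecifierSet

# Same constraint string as the new requirements.txt line for torch
torch_spec = SpecifierSet('>=1.7.0,!=1.12.0')

print('1.12.0' in torch_spec)  # False - the broken release is skipped
print('1.12.1' in torch_spec)  # True  - later fixes are allowed
print('1.11.0' in torch_spec)  # True  - older supported versions still pass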
train.py
@@ -27,7 +27,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import yaml
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import lr_scheduler
 from tqdm import tqdm
 
@@ -46,15 +45,15 @@ from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
 from utils.downloads import attempt_download
 from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
-                           check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
-                           increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
-                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
+                           check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
+                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
+                           one_cycle, print_args, print_mutation, strip_optimizer)
 from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
                                torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -248,10 +247,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
 
     # DDP mode
     if cuda and RANK != -1:
-        if check_version(torch.__version__, '1.11.0'):
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
-        else:
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+        model = smart_DDP(model)
 
     # Model attributes
     nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
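For context, train.py follows the torchrun / torch.distributed.run convention of reading LOCAL_RANK, RANK and WORLD_SIZE from the environment, and the DDP wrap it used to perform inline is what smart_DDP() now encapsulates. Below is a minimal, self-contained sketch of that convention; the toy nn.Linear model and the process-group setup are illustrative assumptions, not code from this repository:

# Minimal sketch (not from this PR): how the RANK / LOCAL_RANK convention used by
# train.py maps onto a DistributedDataParallel wrap. Assumes the script is launched
# with torchrun / torch.distributed.run, which sets these environment variables.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # GPU index on this node
RANK = int(os.getenv('RANK', -1))              # global process index
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))   # total number of processes

if __name__ == '__main__':
    model = nn.Linear(10, 10)                  # toy stand-in for the YOLOv5 model
    if RANK != -1:                             # launched in distributed mode
        torch.cuda.set_device(LOCAL_RANK)
        dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo')
        model = DDP(model.cuda(), device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)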
utils/torch_utils.py
@@ -17,8 +17,13 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
 
-from utils.general import LOGGER, colorstr, file_date, git_describe
+from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
+
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 try:
     import thop  # for FLOPs computation
@@ -29,6 +34,17 @@ except ImportError:
 warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
 
 
+def smart_DDP(model):
+    # Model DDP creation with checks
+    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
+        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
+        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
+    if check_version(torch.__version__, '1.11.0'):
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
+    else:
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+
+
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     # Decorator to make all processes in distributed training wait for each local_master to do something