glenn-jocher and pre-commit-ci[bot] committed
Commit 9cf5fd5 · unverified · 1 Parent(s): 51fb467

assert torch!=1.12.0 for DDP training (#8621)


* assert torch!=1.12.0 for DDP training

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
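
In short, the !=1.12.0 / !=0.13.0 pins are dropped from requirements.txt and the restriction is instead enforced at runtime, inside a new smart_DDP() helper in utils/torch_utils.py that wraps DistributedDataParallel and refuses to build it on torch 1.12.0. The snippet below illustrates the two version checks the helper relies on (a sketch; it assumes check_version() from utils/general.py returns a bool, with pinned=True meaning exact equality and the default meaning a minimum-version test, as the call sites in the diff suggest):

import torch

from utils.general import check_version  # assumes the YOLOv5 repo root is on sys.path

broken_ddp = check_version(torch.__version__, '1.12.0', pinned=True)  # True only on torch==1.12.0
static_graph_ok = check_version(torch.__version__, '1.11.0')  # True on torch>=1.11.0
print(f'torch {torch.__version__}: DDP broken={broken_ddp}, static_graph supported={static_graph_ok}')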

Files changed (3)
  1. requirements.txt +2 -2
  2. train.py +5 -9
  3. utils/torch_utils.py +17 -1
requirements.txt CHANGED
@@ -9,8 +9,8 @@ Pillow>=7.1.2
 PyYAML>=5.3.1
 requests>=2.23.0
 scipy>=1.4.1
-torch>=1.7.0,!=1.12.0  # https://github.com/ultralytics/yolov5/issues/8395
-torchvision>=0.8.1,!=0.13.0  # https://github.com/ultralytics/yolov5/issues/8395
+torch>=1.7.0
+torchvision>=0.8.1
 tqdm>=4.64.0
 protobuf<4.21.3  # https://github.com/ultralytics/yolov5/issues/8012
 
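With the hard pins gone, torch 1.12.0 and torchvision 0.13.0 install cleanly again; the constraint now only applies when DDP training is actually requested. A repo-independent pre-flight check along these lines could still flag the known-bad pair up front (illustrative only, not part of this commit):

from importlib.metadata import PackageNotFoundError, version

# Versions excluded by the old requirements pins (see https://github.com/ultralytics/yolov5/issues/8395)
for pkg, broken in (('torch', '1.12.0'), ('torchvision', '0.13.0')):
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        continue  # package not installed, nothing to check
    if installed == broken:
        print(f'WARNING: {pkg}=={installed} is known to break DDP training')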
 
train.py CHANGED
@@ -27,7 +27,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import yaml
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import lr_scheduler
 from tqdm import tqdm
 
@@ -46,15 +45,15 @@ from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
 from utils.downloads import attempt_download
 from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
-                           check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
-                           increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
-                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
+                           check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
+                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
+                           one_cycle, print_args, print_mutation, strip_optimizer)
 from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
                                torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -248,10 +247,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
 
     # DDP mode
     if cuda and RANK != -1:
-        if check_version(torch.__version__, '1.11.0'):
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
-        else:
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+        model = smart_DDP(model)
 
     # Model attributes
     nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
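
For reference, the new smart_DDP call only runs under a distributed launcher, which sets the RANK/LOCAL_RANK/WORLD_SIZE environment variables read at the top of train.py (and now also in utils/torch_utils.py). A minimal sketch of that gating, with the usual YOLOv5 multi-GPU launch shown in the comment (the toy Linear model and the process-group setup here are illustrative assumptions, not code from train.py):

# Typically launched as: python -m torch.distributed.run --nproc_per_node 2 train.py --device 0,1
import os

import torch
import torch.distributed as dist

from utils.torch_utils import smart_DDP  # helper added by this commit

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # -1 when not launched with torch.distributed.run
RANK = int(os.getenv('RANK', -1))

cuda = torch.cuda.is_available()
model = torch.nn.Linear(8, 4)  # stand-in for the YOLOv5 model
if cuda and RANK != -1:  # DDP mode, mirroring the branch in train()
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo')
    model = smart_DDP(model.to(LOCAL_RANK))  # asserts torch != 1.12.0 before wrapping in DDP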
utils/torch_utils.py CHANGED
@@ -17,8 +17,13 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
 
-from utils.general import LOGGER, colorstr, file_date, git_describe
+from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
+
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 try:
     import thop  # for FLOPs computation
@@ -29,6 +34,17 @@ except ImportError:
 warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
 
 
+def smart_DDP(model):
+    # Model DDP creation with checks
+    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
+        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
+        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
+    if check_version(torch.__version__, '1.11.0'):
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
+    else:
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+
+
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     # Decorator to make all processes in distributed training wait for each local_master to do something
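
Because the guard now lives in a plain helper rather than inline in train(), it is also easy to exercise in isolation. A possible unit test for the new assertion (a sketch; YOLOv5 does not ship this test, and pytest with its monkeypatch fixture is assumed here):

import pytest
import torch

from utils import torch_utils


def test_smart_ddp_rejects_torch_1_12_0(monkeypatch):
    # Pretend the broken release is installed; the assertion should fire before DDP is built.
    monkeypatch.setattr(torch, '__version__', '1.12.0')
    with pytest.raises(AssertionError, match='8395'):
        torch_utils.smart_DDP(torch.nn.Linear(2, 2))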