# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner, get_dist_info)

from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import build_dataloader, build_dataset

from mmocr import digit_version
from mmocr.apis.utils import (disable_text_recog_aug_test,
                              replace_image_to_tensor)
from mmocr.utils import get_root_logger


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
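    """Train a text detection/recognition model with the given config.

    Args:
        model (nn.Module): The model to train.
        dataset (Dataset | list[Dataset]): Training dataset(s).
        cfg (mmcv.Config): The full training config.
        distributed (bool): Whether to launch distributed training.
        validate (bool): Whether to evaluate on the validation set during
            training.
        timestamp (str, optional): Timestamp used to name the log files.
        meta (dict, optional): Meta info to be recorded by the runner.
    """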
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # step 1: set default values, then override them with cfg.data (if given)
    loader_cfg = {
        **dict(
            seed=cfg.get('seed'),
            drop_last=False,
            dist=distributed,
            num_gpus=len(cfg.gpu_ids)),
        **({} if torch.__version__ != 'parrots' else dict(
            prefetch_num=2,
            pin_memory=False,
        )),
        **dict((k, cfg.data[k]) for k in [
            'samples_per_gpu',
            'workers_per_gpu',
            'shuffle',
            'seed',
            'drop_last',
            'prefetch_num',
            'pin_memory',
            'persistent_workers',
        ] if k in cfg.data)
    }

    # step 2: cfg.data.train_dataloader has the highest priority
    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        if not torch.cuda.is_available():
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU training!'
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
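    # Backward compatibility: older configs only define `total_epochs` and no
    # `runner` section, so fall back to an EpochBasedRunner in that case.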
    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp
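    # Select the hook that performs the optimizer step. Priority:
    # 1. Fp16OptimizerHook when a `fp16` section is present in the config;
    # 2. a default OptimizerHook for distributed runs without a custom type;
    # 3. otherwise the user-provided `optimizer_config`, passed through as-is.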
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config
    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
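    # DistSamplerSeedHook calls `set_epoch` on the distributed sampler before
    # each epoch, so shuffling stays consistent across ranks while still
    # changing from epoch to epoch.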
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
            'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
        if val_samples_per_gpu > 1:
            # Support batch_size > 1 in test for text recognition by disabling
            # MultiRotateAugOCR, since it is useless in most cases.
            cfg = disable_text_recog_aug_test(cfg)
            cfg = replace_image_to_tensor(cfg)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        val_loader_cfg = {
            **loader_cfg,
            **dict(shuffle=False, drop_last=False),
            **cfg.data.get('val_dataloader', {}),
            **dict(samples_per_gpu=val_samples_per_gpu)
        }

        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)

        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
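    # `resume_from` restores the full training state (weights, optimizer and
    # epoch/iteration counters); `load_from` only loads the model weights.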
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)


def init_random_seed(seed=None, device='cuda'):
    """Initialize the random seed.

    If the seed is None, it will be replaced by a random number and then
    broadcast to all processes.

    Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device where the seed is placed. Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
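
    Example:
        >>> # Hypothetical usage: with no explicit seed, every rank receives
        >>> # the same randomly drawn value.
        >>> seed = init_random_seed()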
| """ | |
| if seed is not None: | |
| return seed | |
| # Make sure all ranks share the same random seed to prevent | |
| # some potential bugs. Please refer to | |
| # https://github.com/open-mmlab/mmdetection/issues/6339 | |
| rank, world_size = get_dist_info() | |
| seed = np.random.randint(2**31) | |
| if world_size == 1: | |
| return seed | |
| if rank == 0: | |
| random_num = torch.tensor(seed, dtype=torch.int32, device=device) | |
| else: | |
| random_num = torch.tensor(0, dtype=torch.int32, device=device) | |
| dist.broadcast(random_num, src=0) | |
| return random_num.item() | |