# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner, get_dist_info)
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import build_dataloader, build_dataset

from mmocr import digit_version
from mmocr.apis.utils import (disable_text_recog_aug_test,
                              replace_image_to_tensor)
from mmocr.utils import get_root_logger


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
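    """Train a text detection or recognition model with the given config.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset | list[Dataset]): The training dataset(s).
        cfg (mmcv.Config): The whole training config.
        distributed (bool): Whether to launch distributed training.
        validate (bool): Whether to evaluate the model on the validation
            set during training.
        timestamp (str, optional): Timestamp used to name the log files.
        meta (dict, optional): Meta info carried by the runner.
    """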
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # step 1: set default values and override them (if the keys exist)
    # with values from cfg.data
    loader_cfg = {
        **dict(
            seed=cfg.get('seed'),
            drop_last=False,
            dist=distributed,
            num_gpus=len(cfg.gpu_ids)),
        **({} if torch.__version__ != 'parrots' else dict(
            prefetch_num=2,
            pin_memory=False,
        )),
        **dict((k, cfg.data[k]) for k in [
            'samples_per_gpu',
            'workers_per_gpu',
            'shuffle',
            'seed',
            'drop_last',
            'prefetch_num',
            'pin_memory',
            'persistent_workers',
        ] if k in cfg.data)
    }

    # step 2: cfg.data.train_dataloader has highest priority
    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
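
    # Illustrative example of the two-step merge above (the config values
    # are hypothetical): with cfg.data = dict(samples_per_gpu=8,
    # workers_per_gpu=4, train_dataloader=dict(samples_per_gpu=16)),
    # step 1 copies samples_per_gpu=8 and workers_per_gpu=4 into
    # loader_cfg, and step 2 overrides samples_per_gpu to 16 in
    # train_loader_cfg.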

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        if not torch.cuda.is_available():
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU training!'
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config
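
    # Illustrative: a config typically enables mixed precision with
    # something like ``fp16 = dict(loss_scale=512.)`` (the loss-scale value
    # here is hypothetical), which routes the branch above through
    # Fp16OptimizerHook.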

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
            'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
        if val_samples_per_gpu > 1:
            # Support batch_size > 1 in test for text recognition by
            # disabling MultiRotateAugOCR, since it is useless in most cases
            cfg = disable_text_recog_aug_test(cfg)
            cfg = replace_image_to_tensor(cfg)

        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        val_loader_cfg = {
            **loader_cfg,
            **dict(shuffle=False, drop_last=False),
            **cfg.data.get('val_dataloader', {}),
            **dict(samples_per_gpu=val_samples_per_gpu)
        }

        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)

        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
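

# A minimal usage sketch for ``train_detector`` (illustrative only; the
# builder imports and the config path below are assumptions based on a
# typical MMOCR 0.x training script, not part of this module):
#
#     from mmcv import Config
#     from mmdet.datasets import build_dataset
#     from mmocr.models import build_detector
#
#     cfg = Config.fromfile('configs/some_config.py')  # hypothetical path
#     model = build_detector(cfg.model)
#     datasets = [build_dataset(cfg.data.train)]
#     train_detector(model, datasets, cfg, distributed=False, validate=True)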


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed. If the seed is None, it will be replaced by a
    random number, which is then broadcast to all processes.

    Args:
        seed (int, optional): The seed.
        device (str): The device to which the seed will be moved.

    Returns:
        int: Seed to be used.
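
    Examples:
        >>> # A fixed seed is returned unchanged; passing None draws a
        >>> # random seed and (in distributed mode) broadcasts it.
        >>> init_random_seed(42)
        42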
""" | |
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()