Spaces:
Runtime error
Runtime error
File size: 5,977 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from mmengine.config import ConfigDict
def compat_cfg(cfg):
    """Return a copy of ``cfg`` with deprecated fields moved to their
    current locations.

    The input is deep-copied first, so the caller's config is not modified.
    The copy is then passed, in order, through ``compat_imgs_per_gpu``,
    ``compat_loader_args`` and ``compat_runner_args``; the result of the last
    pass is returned.
    """
    patched = copy.deepcopy(cfg)
    # Order matters: `imgs_per_gpu` is first mapped onto `samples_per_gpu`,
    # which the loader pass then moves into the dataloader sections.
    passes = (compat_imgs_per_gpu, compat_loader_args, compat_runner_args)
    for apply_pass in passes:
        patched = apply_pass(patched)
    return patched
def compat_runner_args(cfg):
    """Make sure ``cfg`` carries a ``runner`` section.

    When ``runner`` is already present, the only check performed is that a
    top-level ``total_epochs`` (if any) equals ``cfg.runner.max_epochs``;
    a mismatch fails an ``assert``.

    When ``runner`` is absent, an ``EpochBasedRunner`` section is created
    from ``cfg.total_epochs`` and a ``UserWarning`` is emitted.

    Note that ``cfg`` is modified in place (no copy is made here) and the
    same object is returned.
    """
    if 'runner' in cfg:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
        return cfg

    # Legacy config: synthesise the runner section from `total_epochs`.
    runner_fields = {
        'type': 'EpochBasedRunner',
        'max_epochs': cfg.total_epochs
    }
    cfg.runner = ConfigDict(runner_fields)
    warnings.warn(
        'config is now expected to have a `runner` section, '
        'please set `runner` in your config.', UserWarning)
    return cfg
def compat_imgs_per_gpu(cfg):
    """Map the deprecated ``data.imgs_per_gpu`` onto ``data.samples_per_gpu``.

    Works on a deep copy of ``cfg`` and returns that copy. If
    ``imgs_per_gpu`` is not in ``cfg.data`` the copy is returned untouched.
    Otherwise a deprecation warning is issued and ``samples_per_gpu`` is set
    to the value of ``imgs_per_gpu`` -- overriding an existing
    ``samples_per_gpu`` if there is one (a second warning reports which case
    applied). The ``imgs_per_gpu`` key itself is left in place.
    """
    cfg = copy.deepcopy(cfg)
    data_cfg = cfg.data
    if 'imgs_per_gpu' not in data_cfg:
        return cfg

    warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                  'Please use "samples_per_gpu" instead')
    legacy_value = data_cfg.imgs_per_gpu
    if 'samples_per_gpu' not in data_cfg:
        warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"='
                      f'{legacy_value} in this experiments')
    else:
        # Both keys given: the deprecated one wins, as the message states.
        warnings.warn(
            f'Got "imgs_per_gpu"={legacy_value} and '
            f'"samples_per_gpu"={data_cfg.samples_per_gpu}, "imgs_per_gpu"'
            f'={legacy_value} is used in this experiments')
    data_cfg.samples_per_gpu = legacy_value
    return cfg
def _assert_loader_arg_unset(loader_cfg, key, src_name, loader_name):
    """Assert that ``key`` is not already set in ``loader_cfg``.

    ``src_name`` and ``loader_name`` are only used to build the message.
    """
    assert key not in loader_cfg, (
        f'`{key}` are set in `{src_name}` field and ` {loader_name}` '
        f'at the same time. Please only set it in `{loader_name}`. ')


def compat_loader_args(cfg):
    """Move deprecated loader arguments of ``cfg.data`` into the
    ``train_dataloader`` / ``val_dataloader`` / ``test_dataloader`` sections.

    The function works on a deep copy of ``cfg`` and returns that copy.
    Missing dataloader sections are created as empty ``ConfigDict`` objects.
    Then:

    * ``data.samples_per_gpu`` and ``data.persistent_workers`` are moved to
      ``data.train_dataloader``.
    * ``data.workers_per_gpu`` is moved to all three dataloader sections
      (overwriting any value already there).
    * ``data.val.samples_per_gpu`` is moved to ``data.val_dataloader``.
    * ``data.test.samples_per_gpu`` is moved to ``data.test_dataloader``.
      When ``data.test`` is a list of dataset configs, the maximum over the
      datasets is used (datasets without the key count as 1).

    ``cfg.data.val`` and ``cfg.data.test`` are accessed unconditionally.

    Raises:
        AssertionError: if a moved ``samples_per_gpu`` /
            ``persistent_workers`` value is also present in the target
            dataloader section.
    """
    cfg = copy.deepcopy(cfg)
    data = cfg.data

    loader_keys = ('train_dataloader', 'val_dataloader', 'test_dataloader')
    for loader_key in loader_keys:
        if loader_key not in data:
            data[loader_key] = ConfigDict()

    # special process for train_dataloader
    for key in ('samples_per_gpu', 'persistent_workers'):
        if key in data:
            _assert_loader_arg_unset(data.train_dataloader, key, 'data',
                                     'data.train_dataloader')
            data.train_dataloader[key] = data.pop(key)

    # `workers_per_gpu` is shared by every split
    if 'workers_per_gpu' in data:
        workers_per_gpu = data.pop('workers_per_gpu')
        for loader_key in loader_keys:
            data[loader_key]['workers_per_gpu'] = workers_per_gpu

    # special process for val_dataloader
    if 'samples_per_gpu' in data.val:
        _assert_loader_arg_unset(data.val_dataloader, 'samples_per_gpu',
                                 'data.val', 'data.val_dataloader')
        data.val_dataloader['samples_per_gpu'] = \
            data.val.pop('samples_per_gpu')

    # special process for test_dataloader
    test_cfg = data.test
    if isinstance(test_cfg, dict):
        if 'samples_per_gpu' in test_cfg:
            _assert_loader_arg_unset(data.test_dataloader, 'samples_per_gpu',
                                     'data.test', 'data.test_dataloader')
            data.test_dataloader['samples_per_gpu'] = \
                test_cfg.pop('samples_per_gpu')
    elif isinstance(test_cfg, list):
        # in case the test dataset is concatenated
        declared = any('samples_per_gpu' in ds_cfg for ds_cfg in test_cfg)
        if declared:
            _assert_loader_arg_unset(data.test_dataloader, 'samples_per_gpu',
                                     'data.test', 'data.test_dataloader')
            data.test_dataloader['samples_per_gpu'] = max(
                ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in test_cfg)
        elif 'samples_per_gpu' not in data.test_dataloader:
            # No dataset declares a value: fall back to 1, but never
            # clobber a value given explicitly in `data.test_dataloader`.
            data.test_dataloader['samples_per_gpu'] = 1
    return cfg
|