Spaces:
Runtime error
Runtime error
File size: 2,914 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
# Copyright (c) OpenMMLab. All rights reserved.
import re
from mmengine.config import Config
def replace_cfg_vals(ori_cfg):
    """Replace the string "${key}" with the corresponding value.

    Replace the "${key}" with the value of ori_cfg.key in the config. And
    support replacing the chained ${key}. Such as, replace "${key0.key1}"
    with the value of cfg.key0.key1. Code is modified from `vars.py
    <https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_  # noqa: E501

    Args:
        ori_cfg (mmengine.config.Config):
            The origin config with "${key}" generated from a file.

    Returns:
        mmengine.config.Config: The config with "${key}" replaced by the
        corresponding value. If the result has a non-None ``model_wrapper``
        entry, it is moved to ``model`` and ``model_wrapper`` is removed.

    Raises:
        AssertionError: If a "${key}" embedded inside a longer string
            resolves to a dict, list or tuple, which cannot be spliced
            into the surrounding text.

    Note:
        A "${key}" whose dotted path is missing from ``ori_cfg`` is not
        handled here; whatever error the underlying container raises on
        the failed lookup propagates to the caller.
    """
    # the pattern of string "${key}"
    pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}')

    def get_value(cfg, key):
        # Walk the dotted path one component at a time, e.g. "a.b" -> cfg['a']['b'].
        for k in key.split('.'):
            cfg = cfg[k]
        return cfg

    def replace_value(cfg):
        if isinstance(cfg, dict):
            return {key: replace_value(value) for key, value in cfg.items()}
        elif isinstance(cfg, list):
            return [replace_value(item) for item in cfg]
        elif isinstance(cfg, tuple):
            return tuple(replace_value(item) for item in cfg)
        elif isinstance(cfg, str):
            # the format of string cfg may be:
            # 1) "${key}", which will be replaced with cfg.key directly
            # 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx",
            # which will be replaced with the string of the cfg.key
            keys = pattern_key.findall(cfg)
            # key[2:-1] strips the leading "${" and the trailing "}".
            # Values are always looked up in the original, unreplaced config.
            values = [get_value(ori_cfg, key[2:-1]) for key in keys]
            if len(keys) == 1 and keys[0] == cfg:
                # the format of string cfg is "${key}": keep the value's
                # own type instead of converting it to a string
                cfg = values[0]
            else:
                for key, value in zip(keys, values):
                    # the format of string cfg is
                    # "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx"
                    # `key` already holds the full "${...}" token, so it is
                    # interpolated as-is; the format examples stay literal.
                    assert not isinstance(value, (dict, list, tuple)), (
                        'for the format of string cfg is '
                        "'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', "
                        f"the type of the value of '{key}' "
                        'can not be dict, list, or tuple, '
                        f'but you input {type(value)} in {cfg}')
                    cfg = cfg.replace(key, str(value))
            return cfg
        else:
            return cfg

    # the type of ori_cfg._cfg_dict is mmengine.config.ConfigDict
    updated_cfg = Config(
        replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename)
    # replace the model with model_wrapper
    if updated_cfg.get('model_wrapper', None) is not None:
        updated_cfg.model = updated_cfg.model_wrapper
        updated_cfg.pop('model_wrapper')
    return updated_cfg
|