|
|
|
import glob |
|
import os |
|
import os.path as osp |
|
import warnings |
|
|
|
import mmcv |
|
import torch |
|
from mmcv.utils import TORCH_VERSION, digit_version, print_log |
|
|
|
|
|
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    ``latest.{suffix}`` is returned right away when it exists. Otherwise the
    checkpoint whose file name ends with the largest integer (for example the
    ``12`` in ``epoch_12.pth``) is selected. Files whose name does not end
    with an integer (for example ``best.pth``) are skipped instead of
    aborting the search with a ``ValueError``.

    Args:
        path(str): The path to find checkpoints.
        suffix(str): File extension.
            Defaults to pth.

    Returns:
        latest_path(str | None): File path of the latest checkpoint, or None
            when the path does not exist or no usable checkpoint is found.

    References:
        .. [1]  https://github.com/microsoft/SoftTeacher
                  /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None

    latest_link = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_link):
        return latest_link

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path.')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # The counter is the text between the last '_' and the first '.'
        # that follows it, e.g. 'epoch_12.pth' -> '12'.
        token = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(token)
        except ValueError:
            # Not an iteration/epoch checkpoint (e.g. 'best.pth'); ignore it.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint

    if latest_path is None:
        warnings.warn('No checkpoint with a numeric suffix was found in '
                      'the path.')
    return latest_path
|
|
|
|
|
def update_data_root(cfg, logger=None):
    """Update data root according to env MMDET_DATASETS.

    If set env MMDET_DATASETS, update cfg.data_root according to
    MMDET_DATASETS. Otherwise, using cfg.data_root as default.

    Every string value stored directly in ``cfg.data`` or in a
    ``mmcv.ConfigDict`` nested inside it has occurrences of the old
    ``cfg.data_root`` replaced by the new root. Values held inside lists or
    tuples are not visited. ``cfg`` is modified in place.

    Args:
        cfg (mmcv.Config): The model config need to modify
        logger (logging.Logger | str | None): the way to print msg
    """
    assert isinstance(cfg, mmcv.Config), \
        f'cfg got wrong type: {type(cfg)}, expected mmcv.Config'

    if 'MMDET_DATASETS' not in os.environ:
        # Nothing to override; keep cfg.data_root untouched.
        return

    dst_root = os.environ['MMDET_DATASETS']
    # Forward ``logger`` so the message goes where the caller asked for.
    print_log(
        f'MMDET_DATASETS has been set to be {dst_root}. '
        f'Using {dst_root} as data root.',
        logger=logger)

    def update(cfg, src_str, dst_str):
        # Recursively rewrite string values that contain the old root.
        for k, v in cfg.items():
            if isinstance(v, mmcv.ConfigDict):
                update(cfg[k], src_str, dst_str)
            if isinstance(v, str) and src_str in v:
                cfg[k] = v.replace(src_str, dst_str)

    update(cfg.data, cfg.data_root, dst_root)
    cfg.data_root = dst_root
|
|
|
|
|
# True when the backend is not parrots and the installed torch version is at
# least 1.8. ``floordiv`` consults this flag to decide whether it may call
# ``torch.div`` with the ``rounding_mode`` keyword.
_torch_version_div_indexing = not (
    'parrots' in TORCH_VERSION
    or digit_version(TORCH_VERSION) < digit_version('1.8'))
|
|
|
|
|
def floordiv(dividend, divisor, rounding_mode='trunc'):
    """Divide ``dividend`` by ``divisor`` with integer-style rounding.

    When ``_torch_version_div_indexing`` is set, the work is delegated to
    ``torch.div`` with the requested ``rounding_mode``. Otherwise the ``//``
    operator is used and ``rounding_mode`` is ignored.

    Args:
        dividend: Value to be divided.
        divisor: Value to divide by.
        rounding_mode (str): Forwarded to ``torch.div`` when available.
            Defaults to 'trunc'.

    Returns:
        The result of ``torch.div(...)`` or of ``dividend // divisor``.
    """
    if not _torch_version_div_indexing:
        # Fallback path: ``rounding_mode`` is not used here.
        return dividend // divisor
    return torch.div(dividend, divisor, rounding_mode=rounding_mode)
|
|