# Spaces: Running | File size: 4,854 Bytes | 241adf2
import os
import torch
from models.core.custom_hooks.shuffle_hooks import ShufflePairedSamplesHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
build_optimizer)
from mmpose.core import DistEvalHook, EvalHook, Fp16OptimizerHook
from mmpose.datasets import build_dataloader
from mmpose.utils import get_root_logger
def train_model(model,
                dataset,
                val_dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Train model entry function.

    Builds the training dataloaders, wraps the model for (distributed)
    data-parallel execution, creates an ``EpochBasedRunner`` with its
    training / shuffle / evaluation hooks, optionally restores a checkpoint,
    and finally runs the workflow for ``cfg.total_epochs`` epochs.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset | list[Dataset] | tuple[Dataset]): Train dataset(s).
            A single dataset is wrapped into a one-element list.
        val_dataset (Dataset): Dataset used to build the validation
            dataloader. Only used when ``validate`` is True.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # NOTE(review): the `{}` fallbacks for samples_per_gpu / workers_per_gpu
    # are kept from the original code; they are not usable sizes, so the
    # config is presumably expected to always define both -- confirm.
    dataloader_setting = dict(
        samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
        workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
        # cfg.gpus will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        pin_memory=False,
    )
    # entries under cfg.data.train_dataloader override the defaults above
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('train_dataloader', {}))

    data_loaders = [
        build_dataloader(ds, **dataloader_setting) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        # NOTE: True has been modified to False for faster training.
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        # hand the raw config to the runner unchanged
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # one shuffle hook per training dataloader, only when configured
    shuffle_cfg = cfg.get('shuffle_cfg', None)
    if shuffle_cfg is not None:
        for data_loader in data_loaders:
            runner.register_hook(
                ShufflePairedSamplesHook(data_loader, **shuffle_cfg))

    # register eval hooks
    if validate:
        # Work on a shallow copy so the caller's cfg.evaluation is not
        # rewritten in place (a second call with the same cfg would otherwise
        # prefix a relative res_folder with work_dir twice).
        eval_cfg = dict(cfg.get('evaluation', None) or {})
        # Place the result folder under work_dir; skip when the evaluation
        # config does not define one instead of raising KeyError.
        if 'res_folder' in eval_cfg:
            eval_cfg['res_folder'] = os.path.join(cfg.work_dir,
                                                  eval_cfg['res_folder'])
        dataloader_setting = dict(
            # validation uses one sample per GPU rather than
            # cfg.data.samples_per_gpu
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            shuffle=False,
            pin_memory=False,
        )
        # entries under cfg.data.val_dataloader override the defaults above
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # resuming takes precedence over loading weights only
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
|