# CapeX: models/apis/train.py
# Author: matanru. Revision 93b49a4 ("initial commit"), 4.85 kB.
import os
import torch
from models.core.custom_hooks.shuffle_hooks import ShufflePairedSamplesHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
build_optimizer)
from mmpose.core import DistEvalHook, EvalHook, Fp16OptimizerHook
from mmpose.datasets import build_dataloader
from mmpose.utils import get_root_logger
def train_model(model,
                dataset,
                val_dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Train model entry function.

    Builds the training data loaders, wraps the model for (distributed)
    data parallelism, creates an ``EpochBasedRunner``, registers the
    training / shuffle / evaluation hooks and runs the configured workflow.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset | list[Dataset] | tuple[Dataset]): Train dataset(s).
            A single dataset is wrapped into a one-element list; one data
            loader is built per dataset.
        val_dataset (Dataset): Dataset for the evaluation hook. Only used
            when ``validate`` is True.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    data_loaders = _build_train_dataloaders(dataset, cfg, distributed)

    # put model on gpus
    model = _wrap_model(model, cfg, distributed)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # register hooks (order matters: training hooks, sampler seed, shuffle,
    # then evaluation)
    optimizer_config = _build_optimizer_config(cfg, distributed)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    shuffle_cfg = cfg.get('shuffle_cfg', None)
    if shuffle_cfg is not None:
        # One shuffle hook per training loader.
        for data_loader in data_loaders:
            runner.register_hook(
                ShufflePairedSamplesHook(data_loader, **shuffle_cfg))

    # register eval hooks
    if validate:
        _register_eval_hook(runner, val_dataset, cfg, distributed)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


def _build_train_dataloaders(dataset, cfg, distributed):
    """Build one training data loader per dataset in ``dataset``."""
    datasets = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    dataloader_setting = dict(
        samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
        workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
        # cfg.gpus will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        pin_memory=False,
    )
    # Entries in cfg.data.train_dataloader override the defaults above.
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('train_dataloader', {}))
    return [build_dataloader(ds, **dataloader_setting) for ds in datasets]


def _wrap_model(model, cfg, distributed):
    """Move ``model`` to GPU and wrap it in the matching MM*DataParallel."""
    if distributed:
        # Forwarded to torch.nn.parallel.DistributedDataParallel.
        # NOTE: the default was changed from True to False for faster
        # training; set `find_unused_parameters=True` in the config if DDP
        # reports parameters that did not receive gradients.
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        return MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    return MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)


def _build_optimizer_config(cfg, distributed):
    """Return the optimizer hook (or its raw config) for the runner."""
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        return Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    if distributed and 'type' not in cfg.optimizer_config:
        return OptimizerHook(**cfg.optimizer_config)
    # Otherwise hand the config itself to register_training_hooks.
    return cfg.optimizer_config


def _register_eval_hook(runner, val_dataset, cfg, distributed):
    """Build the validation loader and register a (Dist)EvalHook on runner."""
    # Work on a copy: the previous code rewrote cfg.evaluation['res_folder']
    # in place, mutating the caller's config (and, for a relative work_dir,
    # nesting work_dir again whenever the same cfg is reused).
    eval_cfg = dict(cfg.get('evaluation', {}))
    # 'evaluation' may be absent (default {}), so only resolve res_folder
    # when it is actually configured instead of raising KeyError.
    if 'res_folder' in eval_cfg:
        eval_cfg['res_folder'] = os.path.join(cfg.work_dir,
                                              eval_cfg['res_folder'])
    dataloader_setting = dict(
        # Validation uses one sample per GPU by default (the train-time
        # cfg.data.samples_per_gpu is intentionally not used here).
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
        # cfg.gpus will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        shuffle=False,
        pin_memory=False,
    )
    # Entries in cfg.data.val_dataloader override the defaults above.
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('val_dataloader', {}))
    val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
    eval_hook = DistEvalHook if distributed else EvalHook
    runner.register_hook(eval_hook(val_dataloader, **eval_cfg))