ICON / lib /pymaf /core /base_trainer.py
Yuliang's picture
done
2d5f249
raw history blame
No virus
3.73 kB
# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py
from __future__ import division
import logging
from utils import CheckpointSaver
from tensorboardX import SummaryWriter
import torch
from tqdm import tqdm
# Presumably disables tqdm's background monitor thread (tqdm documents
# ``monitor_interval = 0`` as switching it off) -- confirm for the tqdm
# version in use.
tqdm.monitor_interval = 0
# Module-level logger; not referenced by the code visible in this module.
logger = logging.getLogger(__name__)
class BaseTrainer(object):
    """Base class for Trainer objects.

    Takes care of checkpointing/logging/resuming training.

    Subclasses must implement ``init_fn`` and the training hooks (``train``,
    ``train_step``, ``train_summaries``, ``visualize``). ``init_fn`` has to
    populate ``self.models_dict`` (read by ``load_pretrained`` and when
    resuming) and ``self.optimizers_dict`` (read when resuming).
    """

    def __init__(self, options):
        """Set up device, checkpoint saver, summary writer and resume state.

        Args:
            options: configuration object. Attributes read here:
                ``multiprocessing_distributed``, ``gpu``, ``checkpoint_dir``,
                ``overwrite``, ``rank``, ``summary_dir`` and ``resume``.
        """
        self.options = options
        if options.multiprocessing_distributed:
            # Distributed runs use the explicit GPU index from the options.
            self.device = torch.device('cuda', options.gpu)
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
                                     overwrite=options.overwrite)
        # Only rank 0 gets a SummaryWriter; on other ranks the attribute
        # ``summary_writer`` is not set at all.
        if options.rank == 0:
            self.summary_writer = SummaryWriter(self.options.summary_dir)
        # override this function to define your model, optimizers etc.
        self.init_fn()

        self.checkpoint = None
        if options.resume and self.saver.exists_checkpoint():
            self.checkpoint = self.saver.load_checkpoint(
                self.models_dict, self.optimizers_dict)

        if self.checkpoint is None:
            self.epoch_count = 0
            self.step_count = 0
            self.checkpoint_batch_idx = 0
        else:
            # Resume counters (and the in-epoch batch position) from the
            # loaded checkpoint.
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['total_step_count']
            self.checkpoint_batch_idx = self.checkpoint['batch_idx']
        self.best_performance = float('inf')

    def load_pretrained(self, checkpoint_file=None):
        """Load a pretrained checkpoint.

        This is different from resuming training using --resume: only the
        weights of models listed in ``self.models_dict`` are restored; no
        optimizer state or epoch/step counters are touched.

        Args:
            checkpoint_file: path to a ``torch.save``-d dict whose keys are
                model names from ``self.models_dict`` and whose values are
                state dicts. Entries with other keys are ignored. If None,
                nothing is loaded.
        """
        if checkpoint_file is None:
            return
        # Load onto the CPU first. Without map_location, tensors are restored
        # onto the device they were saved from, which fails on CPU-only hosts
        # and puts every distributed rank's copy on that one GPU.
        # load_state_dict copies the values into each model's parameters on
        # whatever device those parameters live.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        for name, model in self.models_dict.items():
            if name in checkpoint:
                model.load_state_dict(checkpoint[name], strict=True)
                print(f'Checkpoint {name} loaded')

    def move_dict_to_device(self, dict, device, tensor2float=False):
        """Move every tensor value of ``dict`` to ``device``, in place.

        Args:
            dict: mapping whose ``torch.Tensor`` values are replaced;
                non-tensor values are left untouched. (The name shadows the
                builtin but is kept for keyword-argument compatibility.)
            device: target passed to ``Tensor.to``.
            tensor2float: if True, tensors are also cast with ``.float()``.
        """
        for k, v in dict.items():
            if isinstance(v, torch.Tensor):
                if tensor2float:
                    v = v.float()
                # Reassigning an existing key does not resize the dict, so
                # this is safe while iterating over items().
                dict[k] = v.to(device)

    # The following methods (with the possible exception of test) have to be
    # implemented in the derived classes.
    def train(self, epoch):
        """Run one training epoch. Must be overridden."""
        raise NotImplementedError('You need to provide a train method')

    def init_fn(self):
        """Build models/optimizers (``models_dict``, ``optimizers_dict``).

        Must be overridden.
        """
        raise NotImplementedError('You need to provide an init_fn method')

    def train_step(self, input_batch):
        """Run one optimization step on ``input_batch``. Must be overridden."""
        raise NotImplementedError('You need to provide a train_step method')

    def train_summaries(self, input_batch):
        """Write training summaries for ``input_batch``. Must be overridden."""
        raise NotImplementedError(
            'You need to provide a train_summaries method')

    def visualize(self, input_batch):
        """Visualize ``input_batch``. Must be overridden."""
        raise NotImplementedError('You need to provide a visualize method')

    def validate(self):
        """Optional validation hook; no-op by default."""
        pass

    def test(self):
        """Optional test hook; no-op by default."""
        pass

    def evaluate(self):
        """Optional evaluation hook; no-op by default."""
        pass

    def fit(self):
        """Run training for epochs ``self.epoch_count .. num_epochs - 1``.

        The progress bar is created with ``initial=self.epoch_count`` and
        ``total=num_epochs``, so a resumed run shows overall progress.
        """
        for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
                          total=self.options.num_epochs,
                          initial=self.epoch_count):
            self.epoch_count = epoch
            self.train(epoch)