Spaces:
Configuration error
Configuration error
import matplotlib | |
from torch.nn import DataParallel | |
from torch.nn.parallel import DistributedDataParallel | |
matplotlib.use('Agg') | |
import glob | |
import itertools | |
import subprocess | |
import threading | |
import traceback | |
from pytorch_lightning.callbacks import GradientAccumulationScheduler | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from functools import wraps | |
from torch.cuda._utils import _get_device_index | |
import numpy as np | |
import torch.optim | |
import torch.utils.data | |
import copy | |
import logging | |
import os | |
import re | |
import sys | |
import torch | |
import torch.distributed as dist | |
import torch.multiprocessing as mp | |
import tqdm | |
from torch.optim.optimizer import Optimizer | |
def get_a_var(obj):  # pragma: no cover
    """Return the first tensor found in ``obj`` (depth-first), or ``None``.

    Lists and tuples are searched element by element; dicts are searched over
    their ``(key, value)`` pairs, so a tensor used as a key is found too.
    """
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        candidates = obj.items()
    else:
        return None
    for candidate in candidates:
        found = get_a_var(candidate)
        if isinstance(found, torch.Tensor):
            return found
    return None
def data_loader(fn):
    """
    Decorator that makes a dataloader-building method lazy and memoized.

    The first call runs ``fn(self)`` and caches the result on the instance
    under ``'_lazy_' + fn.__name__``; later calls return the cached value.
    For ``test_dataloader`` / ``val_dataloader`` a single non-list result is
    wrapped in a list so callers can always iterate over the loaders.

    :param fn: method taking only ``self`` that builds the dataloader(s).
    :return: wrapper carrying ``fn``'s name and docstring.
    :raises RuntimeError: if ``fn`` raises AttributeError (re-raised so it is
        not mistaken for "not cached yet").
    """
    attr_name = '_lazy_' + fn.__name__

    # `@wraps` must decorate the wrapper; a bare `wraps(fn)` call is a no-op.
    @wraps(fn)
    def _get_data_loader(self):
        try:
            value = getattr(self, attr_name)
        except AttributeError:
            try:
                value = fn(self)  # Lazy evaluation, done only once.
                if (
                    value is not None and
                    not isinstance(value, list) and
                    fn.__name__ in ['test_dataloader', 'val_dataloader']
                ):
                    value = [value]
            except AttributeError as e:
                # Guard against AttributeError suppression. (Issue #142)
                traceback.print_exc()
                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
                raise RuntimeError(error) from e
            setattr(self, attr_name, value)  # Memoize evaluation.
        return value

    return _get_data_loader
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):  # pragma: no cover
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Lightning variant of ``torch.nn.parallel.parallel_apply``: instead of
    calling ``module(...)``, each replica runs ``training_step``,
    ``test_step`` or ``validation_step`` depending on its ``training`` /
    ``testing`` flags.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        kwargs_tup (sequence of dict, optional): keyword arguments per module;
            an empty dict per module when omitted.
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.

    Returns:
        list: one output per module, in input order.

    Raises:
        Exception: the first exception (by module index) captured in a worker
        is re-raised in the calling thread.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    # Normalize every entry (int / torch.device / None) to a device index.
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    # Maps module index -> output or the exception raised by that worker.
    results = {}
    # Captured here so each worker thread can reuse the caller's grad mode.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        # Runs one replica's step on its device and stores the result in results[i].
        torch.set_grad_enabled(grad_enabled)
        # NOTE(review): this device lookup sits outside the try below; if it
        # fails inside a worker thread, nothing is stored in `results` and the
        # later `results[i]` lookup raises KeyError instead of the real error.
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                # ---------------
                # CHANGE: dispatch to the Lightning step instead of forward()
                if module.training:
                    output = module.training_step(*input, **kwargs)
                elif module.testing:
                    output = module.test_step(*input, **kwargs)
                else:
                    output = module.validation_step(*input, **kwargs)
                # ---------------
            with lock:
                results[i] = output
        except Exception as e:
            # Stored rather than raised so it can be re-raised in the caller's thread.
            with lock:
                results[i] = e

    # make sure each module knows what training state it's in...
    # fixes weird bug where copies are out of sync
    root_m = modules[0]
    for m in modules[1:]:
        m.training = root_m.training
        m.testing = root_m.testing

    if len(modules) > 1:
        # One thread per replica; join all before collecting results.
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # Single replica: run inline, no thread needed.
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
def _find_tensors(obj): # pragma: no cover | |
r""" | |
Recursively find all tensors contained in the specified object. | |
""" | |
if isinstance(obj, torch.Tensor): | |
return [obj] | |
if isinstance(obj, (list, tuple)): | |
return itertools.chain(*map(_find_tensors, obj)) | |
if isinstance(obj, dict): | |
return itertools.chain(*map(_find_tensors, obj.values())) | |
return [] | |
class DDP(DistributedDataParallel):
    """
    DistributedDataParallel whose forward dispatches to the Lightning step methods.

    Instead of ``module(...)``, the wrapped module's ``training_step``,
    ``test_step`` or ``validation_step`` is called, chosen by its
    ``training`` / ``testing`` flags.

    NOTE(review): relies on DistributedDataParallel internals
    (``_sync_params``, ``_module_copies``, ``reducer``) that exist only in some
    torch versions — confirm against the installed torch.
    """

    def parallel_apply(self, replicas, inputs, kwargs):
        # Use the Lightning-aware module-level parallel_apply for multi-device replicas.
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                # --------------
                # LIGHTNING MOD
                # --------------
                # normal
                # output = self.module(*inputs[0], **kwargs[0])
                # lightning
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])
            else:
                # Several devices in this process: run each replica copy, then gather.
                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            # normal
            output = self.module(*inputs, **kwargs)
        if torch.is_grad_enabled():
            # We'll return the output object verbatim since it is a freeform
            # object. We need to find any tensors in this object, though,
            # because we need to figure out which parameters were used during
            # this forward pass, to ensure we short circuit reduction for any
            # unused parameters. Only if `find_unused_parameters` is set.
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        return output
class DP(DataParallel):
    """
    DataParallel whose forward dispatches to the Lightning step methods.

    Each replica runs ``training_step``, ``test_step`` or ``validation_step``
    (chosen by the module's ``training`` / ``testing`` flags) instead of the
    module's own ``forward``. Without device ids the module is called normally.
    """

    def forward(self, *inputs, **kwargs):
        # No devices configured: plain module call.
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        # All parameters and buffers must already live on device_ids[0].
        tensors = itertools.chain(self.module.parameters(), self.module.buffers())
        misplaced = next((t for t in tensors if t.device != self.src_device_obj), None)
        if misplaced is not None:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(self.src_device_obj, misplaced.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)

        if len(self.device_ids) != 1:
            # Several devices: replicate, run each step in parallel, gather.
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
            per_device = self.parallel_apply(replicas, inputs, kwargs)
            return self.gather(per_device, self.output_device)

        # Single device: call the right step on the original module directly.
        module = self.module
        if module.training:
            step = module.training_step
        elif module.testing:
            step = module.test_step
        else:
            step = module.validation_step
        return step(*inputs[0], **kwargs[0])

    def parallel_apply(self, replicas, inputs, kwargs):
        # Delegate to the Lightning-aware module-level parallel_apply.
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
class GradientAccumulationScheduler:
    """Set ``trainer.accumulate_grad_batches`` from an epoch-indexed schedule.

    ``scheduling`` maps a 1-based epoch number to the accumulation factor that
    applies from that epoch onwards. When epoch 1 is not listed, a factor of 1
    is used until the first scheduled epoch. The caller's dict is copied, not
    modified.
    """

    def __init__(self, scheduling: dict):
        """
        :param scheduling: ``{epoch (int >= 1): accumulation factor (int)}``.
        :raises TypeError: if the dict is empty or contains non-int entries.
        :raises IndexError: if an epoch below 1 is given.
        """
        if scheduling == {}:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correct")

        for key in scheduling.keys():
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epoches and accumulation factor must be integers")

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 1:
            msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
            raise IndexError(msg)

        # Copy so filling in the default epoch-1 factor does not mutate the
        # caller's dict (which may be shared, e.g. a config value).
        scheduling = dict(scheduling)
        if minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
            scheduling[1] = 1
        self.scheduling = scheduling
        self.epochs = sorted(scheduling)

    def on_epoch_begin(self, epoch, trainer):
        """Apply the factor of the latest scheduled epoch <= ``epoch + 1``.

        :param epoch: 0-based epoch index.
        :param trainer: object whose ``accumulate_grad_batches`` is set.
        """
        epoch += 1  # indexing epochs from 1
        for scheduled_epoch in reversed(self.epochs):
            if epoch >= scheduled_epoch:
                trainer.accumulate_grad_batches = self.scheduling.get(scheduled_epoch)
                break
class LatestModelCheckpoint(ModelCheckpoint):
    """Keep the newest ``num_ckpt_keep`` step checkpoints plus a "best" one.

    Every ``period`` epochs the model is written to
    ``{filepath}/{prefix}_ckpt_steps_{global_step}.ckpt`` and older step
    checkpoints beyond ``num_ckpt_keep`` are deleted. With ``save_best``, an
    improvement of ``logs[monitor]`` also writes ``{filepath}/{prefix}_ckpt_best.pt``
    and persists the best value in ``{filepath}/best_valid.npy``, which is read
    back on construction.

    ``save_function`` (called with a path) and ``task`` (must expose
    ``global_step``) have to be assigned before ``on_epoch_end`` runs.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5,
                 save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True):
        # Deliberately skips ModelCheckpoint.__init__ and runs the next
        # initializer in the MRO; every attribute used here is set below.
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        os.makedirs(filepath, exist_ok=True)
        self.num_ckpt_keep = num_ckpt_keep
        self.save_best = save_best
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_check = 0
        self.prefix = prefix
        self.best_k_models = {}  # {filename: monitor}
        self.kth_best_model = ''
        self.save_top_k = 1
        self.task = None  # presumably assigned by the owning task; must expose `global_step`

        # Comparison direction; any mode other than 'min'/'max' infers it from
        # the metric name. np.inf is used because the np.Inf alias was removed
        # in NumPy 2.0.
        if mode == 'min':
            self.monitor_op, self.best, self.mode = np.less, np.inf, 'min'
        elif mode == 'max':
            self.monitor_op, self.best, self.mode = np.greater, -np.inf, 'max'
        elif 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
            self.monitor_op, self.best, self.mode = np.greater, -np.inf, 'max'
        else:
            self.monitor_op, self.best, self.mode = np.less, np.inf, 'min'

        # Resume the best score recorded by a previous run in this directory.
        best_valid_path = f'{self.filepath}/best_valid.npy'
        if os.path.exists(best_valid_path):
            self.best = np.load(best_valid_path)[0]

    def get_all_ckpts(self):
        """Return this prefix's step checkpoints, highest step first.

        Files matched by the glob whose name has no numeric step are ignored
        instead of raising IndexError.
        """
        step_re = re.compile(r'.*steps_(\d+)\.ckpt')
        steps_by_path = {}
        for path in glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'):
            found = step_re.findall(path)
            if found:
                steps_by_path[path] = int(found[0])
        return sorted(steps_by_path, key=lambda path: -steps_by_path[path])

    def on_epoch_end(self, epoch, logs=None):
        """Periodically save a step checkpoint, prune old ones and track the best.

        :param epoch: epoch index (used in log messages).
        :param logs: metrics dict; ``logs[self.monitor]`` drives best-model saving.
        """
        logs = logs or {}
        self.epochs_since_last_check += 1
        best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt'
        if self.epochs_since_last_check >= self.period:
            self.epochs_since_last_check = 0
            filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt'
            if self.verbose > 0:
                logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}')
            self._save_model(filepath)
            # Keep only the newest `num_ckpt_keep` step checkpoints.
            for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]:
                # TODO: test filesystem calls
                os.remove(old_ckpt)
                if self.verbose > 0:
                    logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}')
            current = logs.get(self.monitor)
            if current is not None and self.save_best:
                if self.monitor_op(current, self.best):
                    self.best = current
                    if self.verbose > 0:
                        logging.info(
                            f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached'
                            f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                            f' {best_filepath} as top 1')
                    self._save_model(best_filepath)
                    np.save(f'{self.filepath}/best_valid.npy', [self.best])

    def _save_model(self, path):
        """Write a checkpoint to ``path`` via the externally assigned ``save_function``."""
        return self.save_function(path)
class BaseTrainer: | |
    def __init__(
            self,
            logger=True,
            checkpoint_callback=True,
            default_save_path=None,
            gradient_clip_val=0,
            process_position=0,
            gpus=-1,
            log_gpu_memory=None,
            show_progress_bar=True,
            track_grad_norm=-1,
            check_val_every_n_epoch=1,
            accumulate_grad_batches=1,
            max_updates=1000,
            min_epochs=1,
            val_check_interval=1.0,
            log_save_interval=100,
            row_log_interval=10,
            print_nan_grads=False,
            weights_summary='full',
            num_sanity_val_steps=5,
            resume_from_checkpoint=None,
    ):
        """Store training options and derive device / distributed settings.

        * ``checkpoint_callback`` must be an object with a ``filepath``
          attribute; the trainer assigns its ``save_function``. ``logger`` must
          accept a ``rank`` attribute. NOTE(review): the ``True`` defaults
          cannot satisfy this (attributes cannot be set on a bool), so callers
          must always pass real objects.
        * ``gpus`` only seeds ``self.on_gpu``; the device ids actually used are
          parsed from the ``CUDA_VISIBLE_DEVICES`` environment variable, which
          also overrides ``on_gpu``.
        * The distributed backend is 'ddp' when at least one device id is
          visible, 'dp' otherwise (then refined by ``set_distributed_mode``).
        * ``accumulate_grad_batches`` is an int or an ``{epoch: factor}`` dict.
        """
        self.log_gpu_memory = log_gpu_memory
        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.track_grad_norm = track_grad_norm
        # Provisional value; overwritten below from CUDA_VISIBLE_DEVICES.
        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
        self.process_position = process_position
        self.weights_summary = weights_summary
        self.max_updates = max_updates
        self.min_epochs = min_epochs
        self.num_sanity_val_steps = num_sanity_val_steps
        self.print_nan_grads = print_nan_grads
        self.resume_from_checkpoint = resume_from_checkpoint
        self.default_save_path = default_save_path
        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_idx = 0
        self.tqdm_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_training_batches = 0
        self.num_test_batches = 0
        # Filled in later by the init_*_dataloader methods.
        self.get_train_dataloader = None
        self.get_test_dataloaders = None
        self.get_val_dataloaders = None
        self.is_iterable_train_dataloader = False
        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0
        # configure checkpoint callback: it saves through this trainer, and
        # weights go to the callback's directory.
        self.checkpoint_callback = checkpoint_callback
        self.checkpoint_callback.save_function = self.save_checkpoint
        self.weights_save_path = self.checkpoint_callback.filepath
        # accumulated grads
        self.configure_accumulated_gradients(accumulate_grad_batches)
        # Device ids come only from CUDA_VISIBLE_DEVICES (comma-separated ints);
        # the `gpus` argument is not consulted here.
        self.data_parallel_device_ids = [
            int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
        if len(self.data_parallel_device_ids) == 0:
            self.root_gpu = None
            self.on_gpu = False
        else:
            self.root_gpu = self.data_parallel_device_ids[0]
            self.on_gpu = True
        # distributed backend choice
        self.use_ddp = False
        self.use_dp = False
        self.single_gpu = False
        # `self.num_gpus` is read as a value here (and in other methods), so it
        # must be exposed as a property.
        self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp'
        self.set_distributed_mode(self.distributed_backend)
        self.proc_rank = 0
        self.world_size = 1
        self.node_rank = 0
        # can't init progress bar here because starting a new process
        # means the progress_bar won't survive pickling
        self.show_progress_bar = show_progress_bar
        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval
        self.logger = logger
        self.logger.rank = 0
        self.row_log_interval = row_log_interval
def num_gpus(self): | |
gpus = self.data_parallel_device_ids | |
if gpus is None: | |
return 0 | |
else: | |
return len(gpus) | |
def data_parallel(self): | |
return self.use_dp or self.use_ddp | |
def get_model(self): | |
is_dp_module = isinstance(self.model, (DDP, DP)) | |
model = self.model.module if is_dp_module else self.model | |
return model | |
# ----------------------------- | |
# MODEL TRAINING | |
# ----------------------------- | |
def fit(self, model): | |
if self.use_ddp: | |
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,)) | |
else: | |
model.model = model.build_model() | |
if not self.testing: | |
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) | |
if self.use_dp: | |
model.cuda(self.root_gpu) | |
model = DP(model, device_ids=self.data_parallel_device_ids) | |
elif self.single_gpu: | |
model.cuda(self.root_gpu) | |
self.run_pretrain_routine(model) | |
return 1 | |
def init_optimizers(self, optimizers): | |
# single optimizer | |
if isinstance(optimizers, Optimizer): | |
return [optimizers], [] | |
# two lists | |
elif len(optimizers) == 2 and isinstance(optimizers[0], list): | |
optimizers, lr_schedulers = optimizers | |
return optimizers, lr_schedulers | |
# single list or tuple | |
elif isinstance(optimizers, list) or isinstance(optimizers, tuple): | |
return optimizers, [] | |
    def run_pretrain_routine(self, model):
        """Sanity check a few things before starting actual training.

        Wires the trainer and logger into the model, loads the dataloaders and
        restores weights. When ``testing`` it only runs the test evaluation;
        otherwise it optionally runs a short validation sanity check and then
        calls ``self.train()``.

        :param model: the model, possibly wrapped; ``model.module`` is used as
            the reference model when ``self.data_parallel`` is truthy.
        """
        ref_model = model
        # NOTE(review): `self.data_parallel` is read as a value, so it must be a
        # property; a bound method would always be truthy here.
        if self.data_parallel:
            ref_model = model.module
        # give model convenience properties
        ref_model.trainer = self
        # set local properties on the model
        self.copy_trainer_model_properties(ref_model)
        # link up experiment object
        if self.logger is not None:
            ref_model.logger = self.logger
            self.logger.save()
        if self.use_ddp:
            dist.barrier()
        # set up checkpoint callback
        # self.configure_checkpoint_callback()
        # transfer data loaders from model
        self.get_dataloaders(ref_model)
        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model
        # restore training and model before hpc call
        self.restore_weights(model)
        # when testing requested only run test and return
        if self.testing:
            self.run_evaluation(test=True)
            return
        # check if we should run validation during training
        self.disable_validation = self.num_val_batches == 0
        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        # Both hooks fire even when the sanity check below is skipped.
        ref_model.on_sanity_check_start()
        ref_model.on_train_start()
        if not self.disable_validation and self.num_sanity_val_steps > 0:
            # init progress bars for validation sanity check
            pbar = tqdm.tqdm(desc='Validation sanity check',
                             total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
                             leave=False, position=2 * self.process_position,
                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
            self.main_progress_bar = pbar
            # dummy validation progress bar
            self.val_progress_bar = tqdm.tqdm(disable=True)
            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
            # close progress bars
            self.main_progress_bar.close()
            self.val_progress_bar.close()
        # init progress bar
        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
                         file=sys.stdout)
        self.main_progress_bar = pbar
        # clear cache before training
        if self.on_gpu:
            torch.cuda.empty_cache()
        # CORE TRAINING LOOP
        self.train()
def test(self, model): | |
self.testing = True | |
self.fit(model) | |
def training_tqdm_dict(self): | |
tqdm_dict = { | |
'step': '{}'.format(self.global_step), | |
} | |
tqdm_dict.update(self.tqdm_metrics) | |
return tqdm_dict | |
# -------------------- | |
# restore ckpt | |
# -------------------- | |
def restore_weights(self, model): | |
""" | |
To restore weights we have two cases. | |
First, attempt to restore hpc weights. If successful, don't restore | |
other weights. | |
Otherwise, try to restore actual weights | |
:param model: | |
:return: | |
""" | |
# clear cache before restore | |
if self.on_gpu: | |
torch.cuda.empty_cache() | |
if self.resume_from_checkpoint is not None: | |
self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu) | |
else: | |
# restore weights if same exp version | |
self.restore_state_if_checkpoint_exists(model) | |
# wait for all models to restore weights | |
if self.use_ddp: | |
# wait for all processes to catch up | |
dist.barrier() | |
# clear cache after restore | |
if self.on_gpu: | |
torch.cuda.empty_cache() | |
def restore_state_if_checkpoint_exists(self, model): | |
did_restore = False | |
# do nothing if there's not dir or callback | |
no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback) | |
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath): | |
return did_restore | |
# restore trainer state and model if there is a weight for this experiment | |
last_steps = -1 | |
last_ckpt_name = None | |
# find last epoch | |
checkpoints = os.listdir(self.checkpoint_callback.filepath) | |
for name in checkpoints: | |
if '.ckpt' in name and not name.endswith('part'): | |
if 'steps_' in name: | |
steps = name.split('steps_')[1] | |
steps = int(re.sub('[^0-9]', '', steps)) | |
if steps > last_steps: | |
last_steps = steps | |
last_ckpt_name = name | |
# restore last checkpoint | |
if last_ckpt_name is not None: | |
last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name) | |
self.restore(last_ckpt_path, self.on_gpu) | |
logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}') | |
did_restore = True | |
return did_restore | |
    def restore(self, checkpoint_path, on_gpu):
        """Load a checkpoint file into the model and the trainer state.

        :param checkpoint_path: path of a checkpoint written by ``save_checkpoint``.
        :param on_gpu: if True, move the model to ``self.root_gpu`` after loading.
        :return: None
        """
        # Loaded onto CPU first; moved to the GPU below if requested.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints. Newer torch versions default to weights_only=True,
        # which may reject non-tensor entries in the checkpoint; confirm
        # against the installed torch version.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # load model state
        model = self.get_model()
        # load the state_dict on the model automatically
        # strict=False: missing/unexpected keys do not raise (and are not reported here).
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        if on_gpu:
            model.cuda(self.root_gpu)
        # load training state (affects trainer only)
        self.restore_training_state(checkpoint)
        model.global_step = self.global_step
        del checkpoint
        # NOTE(review): every path below returns None, so this block has no
        # effect other than printing an exception raised by the dist calls.
        try:
            if dist.is_initialized() and dist.get_rank() > 0:
                return
        except Exception as e:
            print(e)
            return
def restore_training_state(self, checkpoint): | |
""" | |
Restore trainer state. | |
Model will get its change to update | |
:param checkpoint: | |
:return: | |
""" | |
if self.checkpoint_callback is not None and self.checkpoint_callback is not False: | |
# return allowing checkpoints with meta information (global_step, etc) | |
self.checkpoint_callback.best = checkpoint['checkpoint_callback_best'] | |
self.global_step = checkpoint['global_step'] | |
self.current_epoch = checkpoint['epoch'] | |
if self.testing: | |
return | |
# restore the optimizers | |
optimizer_states = checkpoint['optimizer_states'] | |
for optimizer, opt_state in zip(self.optimizers, optimizer_states): | |
if optimizer is None: | |
return | |
optimizer.load_state_dict(opt_state) | |
# move optimizer to GPU 1 weight at a time | |
# avoids OOM | |
if self.root_gpu is not None: | |
for state in optimizer.state.values(): | |
for k, v in state.items(): | |
if isinstance(v, torch.Tensor): | |
state[k] = v.cuda(self.root_gpu) | |
# restore the lr schedulers | |
lr_schedulers = checkpoint['lr_schedulers'] | |
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): | |
scheduler.load_state_dict(lrs_state) | |
# -------------------- | |
# MODEL SAVE CHECKPOINT | |
# -------------------- | |
def _atomic_save(self, checkpoint, filepath): | |
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. | |
This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once | |
saving is finished. | |
Args: | |
checkpoint (object): The object to save. | |
Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save`` | |
accepts. | |
filepath (str|pathlib.Path): The path to which the checkpoint will be saved. | |
This points to the file that the checkpoint will be stored in. | |
""" | |
tmp_path = str(filepath) + ".part" | |
torch.save(checkpoint, tmp_path) | |
os.replace(tmp_path, filepath) | |
def save_checkpoint(self, filepath): | |
checkpoint = self.dump_checkpoint() | |
self._atomic_save(checkpoint, filepath) | |
def dump_checkpoint(self): | |
checkpoint = { | |
'epoch': self.current_epoch, | |
'global_step': self.global_step | |
} | |
if self.checkpoint_callback is not None and self.checkpoint_callback is not False: | |
checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best | |
# save optimizers | |
optimizer_states = [] | |
for i, optimizer in enumerate(self.optimizers): | |
if optimizer is not None: | |
optimizer_states.append(optimizer.state_dict()) | |
checkpoint['optimizer_states'] = optimizer_states | |
# save lr schedulers | |
lr_schedulers = [] | |
for i, scheduler in enumerate(self.lr_schedulers): | |
lr_schedulers.append(scheduler.state_dict()) | |
checkpoint['lr_schedulers'] = lr_schedulers | |
# add the hparams and state_dict from the model | |
model = self.get_model() | |
checkpoint['state_dict'] = model.state_dict() | |
# give the model a chance to add a few things | |
model.on_save_checkpoint(checkpoint) | |
return checkpoint | |
def copy_trainer_model_properties(self, model): | |
if isinstance(model, DP): | |
ref_model = model.module | |
elif isinstance(model, DDP): | |
ref_model = model.module | |
else: | |
ref_model = model | |
for m in [model, ref_model]: | |
m.trainer = self | |
m.on_gpu = self.on_gpu | |
m.use_dp = self.use_dp | |
m.use_ddp = self.use_ddp | |
m.testing = self.testing | |
m.single_gpu = self.single_gpu | |
def transfer_batch_to_gpu(self, batch, gpu_id): | |
# base case: object can be directly moved using `cuda` or `to` | |
if callable(getattr(batch, 'cuda', None)): | |
return batch.cuda(gpu_id, non_blocking=True) | |
elif callable(getattr(batch, 'to', None)): | |
return batch.to(torch.device('cuda', gpu_id), non_blocking=True) | |
# when list | |
elif isinstance(batch, list): | |
for i, x in enumerate(batch): | |
batch[i] = self.transfer_batch_to_gpu(x, gpu_id) | |
return batch | |
# when tuple | |
elif isinstance(batch, tuple): | |
batch = list(batch) | |
for i, x in enumerate(batch): | |
batch[i] = self.transfer_batch_to_gpu(x, gpu_id) | |
return tuple(batch) | |
# when dict | |
elif isinstance(batch, dict): | |
for k, v in batch.items(): | |
batch[k] = self.transfer_batch_to_gpu(v, gpu_id) | |
return batch | |
# nothing matches, return the value as is without transform | |
return batch | |
def set_distributed_mode(self, distributed_backend): | |
# skip for CPU | |
if self.num_gpus == 0: | |
return | |
# single GPU case | |
# in single gpu case we allow ddp so we can train on multiple | |
# nodes, 1 gpu per node | |
elif self.num_gpus == 1: | |
self.single_gpu = True | |
self.use_dp = False | |
self.use_ddp = False | |
self.root_gpu = 0 | |
self.data_parallel_device_ids = [0] | |
else: | |
if distributed_backend is not None: | |
self.use_dp = distributed_backend == 'dp' | |
self.use_ddp = distributed_backend == 'ddp' | |
elif distributed_backend is None: | |
self.use_dp = True | |
self.use_ddp = False | |
logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}') | |
    def ddp_train(self, gpu_idx, model):
        """
        Per-process entry point for DDP training or testing.

        ``fit`` launches this through ``mp.spawn`` with ``nprocs=self.num_gpus``;
        ``gpu_idx`` is the process index that ``mp.spawn`` passes first.

        :param gpu_idx: local process / GPU index.
        :param model: the unwrapped model; it is built, moved to ``gpu_idx``,
            wrapped via ``model.configure_ddp`` and passed to ``run_pretrain_routine``.
        :return: None
        """
        # otherwise default to node rank 0 (single node only)
        self.node_rank = 0
        # show progressbar only on progress_rank 0
        self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0
        # determine which process we are and world size
        if self.use_ddp:
            self.proc_rank = self.node_rank * self.num_gpus + gpu_idx
            self.world_size = self.num_gpus
        # let the exp know the rank to avoid overwriting logs
        if self.logger is not None:
            self.logger.rank = self.proc_rank
        # NOTE(review): process-group setup is delegated to
        # model.init_ddp_connection (not visible here).
        model.trainer = self
        model.init_ddp_connection(self.proc_rank, self.world_size)
        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        model.model = model.build_model()
        if not self.testing:
            self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
        # MODEL
        # copy model to each gpu
        if self.distributed_backend == 'ddp':
            torch.cuda.set_device(gpu_idx)
        model.cuda(gpu_idx)
        # set model properties before going into wrapper
        self.copy_trainer_model_properties(model)
        # override root GPU
        self.root_gpu = gpu_idx
        if self.distributed_backend == 'ddp':
            device_ids = [gpu_idx]
        else:
            device_ids = None
        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)
        # continue training routine
        self.run_pretrain_routine(model)
def resolve_root_node_address(self, root_node): | |
if '[' in root_node: | |
name = root_node.split('[')[0] | |
number = root_node.split(',')[0] | |
if '-' in number: | |
number = number.split('-')[0] | |
number = re.sub('[^0-9]', '', number) | |
root_node = name + number | |
return root_node | |
def log_metrics(self, metrics, grad_norm_dic, step=None): | |
"""Logs the metric dict passed in. | |
:param metrics: | |
:param grad_norm_dic: | |
""" | |
# added metrics by Lightning for convenience | |
metrics['epoch'] = self.current_epoch | |
# add norms | |
metrics.update(grad_norm_dic) | |
# turn all tensors to scalars | |
scalar_metrics = self.metrics_to_scalars(metrics) | |
step = step if step is not None else self.global_step | |
# log actual metrics | |
if self.proc_rank == 0 and self.logger is not None: | |
self.logger.log_metrics(scalar_metrics, step=step) | |
self.logger.save() | |
def add_tqdm_metrics(self, metrics): | |
for k, v in metrics.items(): | |
if type(v) is torch.Tensor: | |
v = v.item() | |
self.tqdm_metrics[k] = v | |
def metrics_to_scalars(self, metrics): | |
new_metrics = {} | |
for k, v in metrics.items(): | |
if isinstance(v, torch.Tensor): | |
v = v.item() | |
if type(v) is dict: | |
v = self.metrics_to_scalars(v) | |
new_metrics[k] = v | |
return new_metrics | |
    def process_output(self, output, train=False):
        """Reduces output according to the training mode.
        Separates loss from logging and tqdm metrics

        :param output: dict returned by a training/validation/test step. The
            'progress_bar', 'log' and 'hiddens' keys are handled specially;
            every other key becomes a callback metric.
        :param train: True for training-step outputs: the loss is extracted
            and, under DP, per-GPU values are averaged.
        :return: tuple ``(loss, progress_bar_metrics, log_metrics,
            callback_metrics, hiddens)``; ``loss`` is None when ``train`` is False.
        :raises RuntimeError: if ``train`` is True and no loss is found.
        """
        # ---------------
        # EXTRACT CALLBACK KEYS
        # ---------------
        # all keys not progress_bar or log are candidates for callbacks
        callback_metrics = {}
        for k, v in output.items():
            if k not in ['progress_bar', 'log', 'hiddens']:
                callback_metrics[k] = v
        if train and self.use_dp:
            num_gpus = self.num_gpus
            callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
        for k, v in callback_metrics.items():
            if isinstance(v, torch.Tensor):
                callback_metrics[k] = v.item()
        # ---------------
        # EXTRACT PROGRESS BAR KEYS
        # ---------------
        # A missing 'progress_bar' key (or any error while reducing it) yields {}.
        try:
            progress_output = output['progress_bar']
            # reduce progress metrics for tqdm when using dp
            if train and self.use_dp:
                num_gpus = self.num_gpus
                progress_output = self.reduce_distributed_output(progress_output, num_gpus)
            progress_bar_metrics = progress_output
        except Exception:
            progress_bar_metrics = {}
        # ---------------
        # EXTRACT LOGGING KEYS
        # ---------------
        # extract metrics to log to experiment; a missing 'log' key yields {}
        try:
            log_output = output['log']
            # reduce log metrics when using dp
            if train and self.use_dp:
                num_gpus = self.num_gpus
                log_output = self.reduce_distributed_output(log_output, num_gpus)
            log_metrics = log_output
        except Exception:
            log_metrics = {}
        # ---------------
        # EXTRACT LOSS
        # ---------------
        # if output dict doesn't have the keyword loss
        # then assume the output=loss if scalar
        # NOTE(review): `output` is already used as a dict above (`.items()`)
        # and below (`.get`), so a bare-tensor output fails before this
        # fallback can help; only a dict without 'loss' reaches the raise.
        loss = None
        if train:
            try:
                loss = output['loss']
            except Exception:
                if type(output) is torch.Tensor:
                    loss = output
                else:
                    raise RuntimeError(
                        'No `loss` value in the dictionary returned from `model.training_step()`.'
                    )
            # when using dp need to reduce the loss
            if self.use_dp:
                loss = self.reduce_distributed_output(loss, self.num_gpus)
        # ---------------
        # EXTRACT HIDDEN
        # ---------------
        hiddens = output.get('hiddens')
        # use every metric passed in as a candidate for callback
        callback_metrics.update(progress_bar_metrics)
        callback_metrics.update(log_metrics)
        # convert tensors to Python scalars (via .item())
        for k, v in callback_metrics.items():
            if isinstance(v, torch.Tensor):
                callback_metrics[k] = v.item()
        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
def reduce_distributed_output(self, output, num_gpus):
    """Average the per-GPU values gathered by DataParallel.

    With ``num_gpus <= 1`` the output is returned untouched. A bare tensor is
    averaged as a whole. For a dict, nested dicts are reduced recursively, and
    every tensor whose leading dimension equals ``num_gpus`` is replaced by its
    mean. 0-dim tensors, tensors of other lengths and non-tensor values (e.g.
    Python floats) are left as they are. Anything that is neither a tensor nor
    dict-like is returned unchanged.

    :param output: tensor, dict (possibly nested) of metrics, or any other value.
    :param num_gpus: number of GPUs the DP step ran on.
    :return: the reduced output (dicts are updated in place and returned).
    """
    if num_gpus <= 1:
        return output

    def _mean(tensor):
        # torch.mean rejects integer/bool dtypes, so promote those to float first
        if tensor.is_floating_point() or tensor.is_complex():
            return tensor.mean()
        return tensor.float().mean()

    # when using DP, we get one output per gpu -> average outputs and return
    if isinstance(output, torch.Tensor):
        return _mean(output)

    # plain numbers (or other objects) carry nothing to reduce
    if not hasattr(output, 'items'):
        return output

    for key, value in output.items():
        if isinstance(value, dict):
            # recurse on nested dicts
            output[key] = self.reduce_distributed_output(value, num_gpus)
        elif isinstance(value, torch.Tensor) and value.dim() > 0 and value.size(0) == num_gpus:
            # reduce only metrics that have one entry per gpu; scalars pass through
            output[key] = _mean(value)
    return output
def clip_gradients(self):
    """Clip the global gradient norm of the model to ``self.gradient_clip_val``.

    Clipping only happens for a strictly positive ``gradient_clip_val``; zero,
    negative (or NaN) values leave the gradients untouched.
    """
    max_norm = self.gradient_clip_val
    # written as `not > 0` so a NaN threshold is skipped exactly like before
    if not max_norm > 0:
        return
    params = self.get_model().parameters()
    torch.nn.utils.clip_grad_norm_(params, max_norm)
def print_nan_gradients(self):
    """Log every model parameter whose gradient contains a NaN.

    Parameters without a gradient are skipped. Messages go to the root
    ``logging`` logger at INFO level, so they only appear when INFO is enabled.
    """
    model = self.get_model()
    for param in model.parameters():
        if param.grad is not None and torch.isnan(param.grad.float()).any():
            # pass the tensors as %-style args: using `param` itself as the
            # format string makes logging fail with "not all arguments converted"
            logging.info('%s %s', param, param.grad)
def configure_accumulated_gradients(self, accumulate_grad_batches):
    """Build the gradient-accumulation scheduler from the user setting.

    :param accumulate_grad_batches: either an ``int`` (one factor for every
        epoch, registered under epoch key ``1``) or a ``dict`` mapping epoch to
        factor, passed to ``GradientAccumulationScheduler`` as-is.
    :raises TypeError: for any other type.
    """
    # reset first; the effective value is presumably assigned later by the
    # scheduler's on_epoch_begin hook -- verify in GradientAccumulationScheduler
    self.accumulate_grad_batches = None

    if isinstance(accumulate_grad_batches, dict):
        schedule = accumulate_grad_batches
    elif isinstance(accumulate_grad_batches, int):
        schedule = {1: accumulate_grad_batches}
    else:
        raise TypeError("Gradient accumulation supports only int and dict types")

    self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
def get_dataloaders(self, model):
    """Bind the model's dataloader factories for the current mode.

    In test mode only the test loaders are initialised; otherwise the train
    and validation loaders are. With DDP enabled, all processes then meet at
    ``dist.barrier()`` and the bound factories are invoked once more
    (presumably so every rank materialises its loaders after the barrier).

    :param model: object providing ``train_dataloader`` / ``val_dataloader`` /
        ``test_dataloader`` factories.
    """
    if self.testing:
        self.init_test_dataloader(model)
    else:
        self.init_train_dataloader(model)
        self.init_val_dataloader(model)

    if not self.use_ddp:
        return

    dist.barrier()
    if self.testing:
        self.get_test_dataloaders()
    else:
        self.get_train_dataloader()
        self.get_val_dataloaders()
def init_train_dataloader(self, model):
    """Bind ``model.train_dataloader`` and derive the training-loop bookkeeping.

    Sets:
      * ``num_training_batches`` -- ``len()`` of the loader when it is a
        ``torch.utils.data.DataLoader``, otherwise ``float('inf')``;
      * ``is_iterable_train_dataloader`` -- True for the non-DataLoader case;
      * ``val_check_batch`` -- training steps between validation runs, taken
        directly from an ``int`` ``val_check_interval`` or as a fraction of the
        epoch (at least 1) for a float one.

    :param model: object exposing a ``train_dataloader()`` callable.
    :raises ValueError: if a fractional ``val_check_interval`` is outside
        [0, 1] (via ``_percent_range_check``), or is used with a training
        loader of unknown length.
    """
    # attribute name is misspelled, but the epoch loop reads it under this name
    self.fisrt_epoch = True
    self.get_train_dataloader = model.train_dataloader

    # call the factory once instead of once per use
    train_loader = self.get_train_dataloader()
    if isinstance(train_loader, torch.utils.data.DataLoader):
        self.num_training_batches = int(len(train_loader))
        self.is_iterable_train_dataloader = False
    else:
        self.num_training_batches = float('inf')
        self.is_iterable_train_dataloader = True

    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        return

    self._percent_range_check('val_check_interval')
    if self.is_iterable_train_dataloader:
        # a fraction of an unbounded epoch is undefined (int(inf * x) overflows)
        raise ValueError(
            "`val_check_interval` must be an int number of batches when the "
            "training dataloader has no length, but got "
            f"{self.val_check_interval!r}."
        )
    self.val_check_batch = max(1, int(self.num_training_batches * self.val_check_interval))
def init_val_dataloader(self, model):
    """Bind ``model.val_dataloader`` and count the validation batches.

    ``num_val_batches`` becomes:
      * the summed ``len()`` of all loaders when the first one is a
        ``torch.utils.data.DataLoader``;
      * ``float('inf')`` when the first loader is some other iterable;
      * ``0`` when the factory returns ``None`` or an empty list.

    :param model: object exposing a ``val_dataloader()`` callable returning
        ``None`` or a list of dataloaders.
    """
    self.get_val_dataloaders = model.val_dataloader
    self.num_val_batches = 0

    # call the factory once; guard against an empty list before indexing [0]
    val_loaders = self.get_val_dataloaders()
    if val_loaders is not None and len(val_loaders) > 0:
        # only the first loader's type is inspected
        if isinstance(val_loaders[0], torch.utils.data.DataLoader):
            self.num_val_batches = int(sum(len(loader) for loader in val_loaders))
        else:
            # loaders without a known length are treated as unbounded
            self.num_val_batches = float('inf')
def init_test_dataloader(self, model):
    """Bind ``model.test_dataloader`` and count the test batches.

    ``num_test_batches`` becomes:
      * the summed ``len()`` of all loaders when the first one is a
        ``torch.utils.data.DataLoader``;
      * ``float('inf')`` when the first loader is some other iterable;
      * ``0`` when the factory returns ``None`` or an empty list.

    :param model: object exposing a ``test_dataloader()`` callable returning
        ``None`` or a list of dataloaders.
    """
    self.get_test_dataloaders = model.test_dataloader
    # default to 0 (as init_val_dataloader does for num_val_batches) so the
    # attribute is always defined, even when the model provides no test data
    self.num_test_batches = 0

    # call the factory once; guard against an empty list before indexing [0]
    test_loaders = self.get_test_dataloaders()
    if test_loaders is not None and len(test_loaders) > 0:
        # only the first loader's type is inspected
        if isinstance(test_loaders[0], torch.utils.data.DataLoader):
            self.num_test_batches = int(sum(len(loader) for loader in test_loaders))
        else:
            # loaders without a known length are treated as unbounded
            self.num_test_batches = float('inf')
def evaluate(self, model, dataloaders, max_batches, test=False): | |
"""Run evaluation code. | |
:param model: PT model | |
:param dataloaders: list of PT dataloaders | |
:param max_batches: Scalar | |
:param test: boolean | |
:return: | |
""" | |
# enable eval mode | |
model.zero_grad() | |
model.eval() | |
# copy properties for forward overrides | |
self.copy_trainer_model_properties(model) | |
# disable gradients to save memory | |
torch.set_grad_enabled(False) | |
if test: | |
self.get_model().test_start() | |
# bookkeeping | |
outputs = [] | |
# run training | |
for dataloader_idx, dataloader in enumerate(dataloaders): | |
dl_outputs = [] | |
for batch_idx, batch in enumerate(dataloader): | |
if batch is None: # pragma: no cover | |
continue | |
# stop short when on fast_dev_run (sets max_batch=1) | |
if batch_idx >= max_batches: | |
break | |
# ----------------- | |
# RUN EVALUATION STEP | |
# ----------------- | |
output = self.evaluation_forward(model, | |
batch, | |
batch_idx, | |
dataloader_idx, | |
test) | |
# track outputs for collation | |
dl_outputs.append(output) | |
# batch done | |
if test: | |
self.test_progress_bar.update(1) | |
else: | |
self.val_progress_bar.update(1) | |
outputs.append(dl_outputs) | |
# with a single dataloader don't pass an array | |
if len(dataloaders) == 1: | |
outputs = outputs[0] | |
# give model a chance to do something with the outputs (and method defined) | |
model = self.get_model() | |
if test: | |
eval_results_ = model.test_end(outputs) | |
else: | |
eval_results_ = model.validation_end(outputs) | |
eval_results = eval_results_ | |
# enable train mode again | |
model.train() | |
# enable gradients to save memory | |
torch.set_grad_enabled(True) | |
return eval_results | |
def run_evaluation(self, test=False):
    """Run one validation pass (or a test pass with ``test=True``) and publish its metrics.

    Steps:
      1. Create the matching tqdm bar and store it as
         ``self.val_progress_bar`` / ``self.test_progress_bar``.
      2. Delegate the loop to ``self.evaluate``.
      3. Feed non-``None`` results through ``self.process_output`` into the
         progress bar, the logger and ``self.callback_metrics``.
      4. After validation on rank 0, call the checkpoint callback's
         ``on_epoch_end`` with the accumulated callback metrics.

    :param test: evaluate the test dataloaders instead of the validation ones.
    """
    mode = 'test' if test else 'val'

    model = self.get_model()
    model.on_pre_performance_check()

    # pick the dataloaders and the batch budget for the requested mode
    if test:
        dataloaders, max_batches = self.get_test_dataloaders(), self.num_test_batches
    else:
        dataloaders, max_batches = self.get_val_dataloaders(), self.num_val_batches

    # the main progress bar is already closed when testing, so slot 0 is free then
    bar = tqdm.tqdm(
        desc='Testing' if test else 'Validating',
        total=max_batches,
        leave=test,
        position=2 * self.process_position + (not test),
        disable=not self.show_progress_bar,
        dynamic_ncols=True,
        unit='batch',
        file=sys.stdout,
    )
    setattr(self, f'{mode}_progress_bar', bar)

    eval_results = self.evaluate(self.model, dataloaders, max_batches, test)
    if eval_results is not None:
        _, bar_metrics, logged_metrics, cb_metrics, _ = self.process_output(eval_results)
        self.add_tqdm_metrics(bar_metrics)
        self.log_metrics(logged_metrics, {})
        self.callback_metrics.update(cb_metrics)

    model.on_post_performance_check()

    # read unconditionally (as before); only shown on the main bar during training
    postfix = self.training_tqdm_dict
    if not test:
        self.main_progress_bar.set_postfix(**postfix)

    getattr(self, f'{mode}_progress_bar').close()

    # model checkpointing (validation only, rank 0 only)
    if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
        self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
                                              logs=self.callback_metrics)
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
    """Run a single validation or test step on ``batch``.

    ``dataloader_idx`` is forwarded only when more than one dataloader is
    configured for the current mode, so single-loader ``validation_step`` /
    ``test_step`` implementations may omit that argument.

    With ``use_ddp`` / ``use_dp`` the (wrapped) ``model`` is called directly.
    The wrapper presumably routes the call to the right step; confirm in the
    wrapper's forward override. On a single GPU the batch is first moved to
    the first entry of ``data_parallel_device_ids`` (GPU 0 when that is not a
    list). Otherwise ``test_step`` / ``validation_step`` is called on the model.

    :return: the raw step output.
    """
    loaders = self.get_test_dataloaders() if test else self.get_val_dataloaders()
    step_args = [batch, batch_idx]
    if len(loaders) > 1:
        step_args.append(dataloader_idx)

    if self.use_ddp or self.use_dp:
        return model(*step_args)

    if self.single_gpu:
        device_ids = self.data_parallel_device_ids
        root_gpu = device_ids[0] if isinstance(device_ids, list) else 0
        step_args[0] = self.transfer_batch_to_gpu(batch, root_gpu)

    step = model.test_step if test else model.validation_step
    return step(*step_args)
def train(self):
    """Run the epoch loop from ``self.current_epoch`` (hard upper bound 1000000).

    For each epoch:
      1. Under DDP, re-seed the train sampler when it has ``set_epoch``.
      2. Publish the epoch index on the trainer and the model, and compute
         ``self.total_batches`` (training batches plus expected validation
         batches).
      3. Reset ``batch_loss_value``, update the progress-bar description and
         the gradient-accumulation schedule.
      4. Run ``self.run_training_epoch()``, then step every LR scheduler.

    After the loop, the main bar is closed, ``on_train_end`` is called and
    the logger (if any) is finalized.
    """
    model = self.get_model()
    # run all epochs
    for epoch in range(self.current_epoch, 1000000):
        # set seed for distributed sampler (enables shuffling for each epoch);
        # loaders without a `sampler` attribute are simply skipped
        if self.use_ddp:
            sampler = getattr(self.get_train_dataloader(), 'sampler', None)
            if hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(epoch)

        # update training progress in trainer and model
        model = self.get_model()
        model.current_epoch = epoch
        self.current_epoch = epoch

        total_val_batches = 0
        if not self.disable_validation:
            # val can be checked multiple times in epoch
            is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
            val_checks_per_epoch = self.num_training_batches // self.val_check_batch
            val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
            total_val_batches = self.num_val_batches * val_checks_per_epoch

        # total batches includes multiple val checks
        self.total_batches = self.num_training_batches + total_val_batches
        self.batch_loss_value = 0  # accumulated grads

        # an iterable train loader has no meaningful epoch count in the bar
        desc = '' if self.is_iterable_train_dataloader else f'Epoch {epoch + 1}'
        self.main_progress_bar.set_description(desc)

        # changing gradient according accumulation_scheduler
        self.accumulation_scheduler.on_epoch_begin(epoch, self)

        # -----------------
        # RUN TNG EPOCH
        # -----------------
        self.run_training_epoch()

        # update LR schedulers
        if self.lr_schedulers is not None:
            for lr_scheduler in self.lr_schedulers:
                lr_scheduler.step(epoch=self.current_epoch)

    self.main_progress_bar.close()
    model.on_train_end()
    if self.logger is not None:
        self.logger.finalize("success")
def run_training_epoch(self):
    """Iterate the training dataloader for one epoch.

    For every batch (at most ``num_training_batches``):
      * run ``run_training_batch``;
      * run validation every ``val_check_batch`` global steps, except while
        ``fisrt_epoch`` is still set;
      * save the logger every ``log_save_interval`` batches and log metrics
        every ``row_log_interval`` batches;
      * advance ``global_step`` and ``total_batch_idx``.

    The epoch ends early when the batch step returns ``-1``. The process
    exits via ``SystemExit`` once ``global_step`` exceeds ``max_updates``.
    The ``on_epoch_start`` / ``on_epoch_end`` model hooks run when implemented.
    """
    # before epoch hook
    if self.is_function_implemented('on_epoch_start'):
        self.get_model().on_epoch_start()

    for batch_idx, batch in enumerate(self.get_train_dataloader()):
        # stop epoch if we limited the number of training batches
        if batch_idx >= self.num_training_batches:
            break

        self.batch_idx = batch_idx
        model = self.get_model()
        model.global_step = self.global_step

        # ---------------
        # RUN TRAIN STEP
        # ---------------
        batch_result, grad_norm_dic, batch_step_metrics = self.run_training_batch(batch, batch_idx)

        # when returning -1 from train_step, we end epoch early
        early_stop_epoch = batch_result == -1

        # ---------------
        # RUN VAL STEP
        # ---------------
        # `fisrt_epoch` (sic) is the flag name set when the train loader is
        # initialised; it suppresses validation on the very first step
        should_check_val = (not self.disable_validation
                            and self.global_step % self.val_check_batch == 0
                            and not self.fisrt_epoch)
        self.fisrt_epoch = False
        if should_check_val:
            self.run_evaluation(test=self.testing)

        # when logs should be saved (only rank 0 owns the logger)
        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
        if should_save_log and self.proc_rank == 0 and self.logger is not None:
            self.logger.save()

        # when metrics should be logged
        if batch_idx % self.row_log_interval == 0 or early_stop_epoch:
            self.log_metrics(batch_step_metrics, grad_norm_dic)

        self.global_step += 1
        self.total_batch_idx += 1

        # end epoch early when requested by the batch step
        if early_stop_epoch:
            break
        if self.global_step > self.max_updates:
            print("| Training end..")
            # sys.exit, not the site-module `exit()`: the latter is meant for
            # the interactive shell and is missing under `python -S`
            sys.exit()

    # epoch end hook
    if self.is_function_implemented('on_epoch_end'):
        self.get_model().on_epoch_end()
def run_training_batch(self, batch, batch_idx):
    """Run forward/backward for one training batch and, when due, an optimizer step.

    For every (non-None) optimizer, ``training_forward`` is run inside
    ``optimizer_closure``; its loss is divided by ``accumulate_grad_batches``
    and back-propagated through ``model.backward``. ``optimizer_step`` (plus
    optional grad-norm tracking and clipping) only runs when
    ``(self.batch_idx + 1)`` is a multiple of ``accumulate_grad_batches``.
    ``self.batch_idx`` is set by the caller and presumably equals
    ``batch_idx``; verify.

    With several optimizers, ``requires_grad`` is toggled so that only the
    current optimizer's parameters receive gradients. The last toggle is not
    undone in this method.

    :param batch: the batch from the train dataloader (``None`` is skipped).
    :param batch_idx: index of the batch within the epoch.
    :return: ``(status, grad_norm_dic, log_metrics)``. ``status`` is ``-1``
        when the model's ``on_batch_start`` hook returned ``-1`` (the hook's
        output is then skipped and empty metrics are returned), otherwise ``0``.
    """
    # track grad norms
    grad_norm_dic = {}

    # track all metrics for callbacks
    all_callback_metrics = []

    # track metrics to log
    all_log_metrics = []

    if batch is None:
        return 0, grad_norm_dic, {}

    # hook: the model may veto this batch by returning -1
    if self.is_function_implemented('on_batch_start'):
        model_ref = self.get_model()
        response = model_ref.on_batch_start(batch)

        if response == -1:
            return -1, grad_norm_dic, {}

    # a single split: the whole batch (no truncated-BPTT splitting here)
    splits = [batch]
    self.hiddens = None
    for split_idx, split_batch in enumerate(splits):
        self.split_idx = split_idx

        # call training_step once per optimizer
        for opt_idx, optimizer in enumerate(self.optimizers):
            if optimizer is None:
                continue
            # make sure only the gradients of the current optimizer's paramaters are calculated
            # in the training step to prevent dangling gradients in multiple-optimizer setup.
            if len(self.optimizers) > 1:
                for param in self.get_model().parameters():
                    param.requires_grad = False
                for group in optimizer.param_groups:
                    for param in group['params']:
                        param.requires_grad = True

            # wrap the forward step in a closure so second order methods work
            def optimizer_closure():
                # forward pass; returns (loss, progress-bar metrics, log metrics,
                # callback metrics, hiddens) as unpacked below
                output = self.training_forward(
                    split_batch, batch_idx, opt_idx, self.hiddens)

                closure_loss = output[0]
                progress_bar_metrics = output[1]
                log_metrics = output[2]
                callback_metrics = output[3]
                self.hiddens = output[4]
                # no loss -> nothing to back-propagate; the caller skips the step
                if closure_loss is None:
                    return None

                # accumulate loss
                # (if accumulate_grad_batches = 1 no effect)
                closure_loss = closure_loss / self.accumulate_grad_batches

                # backward pass (skipped for losses that are not attached to a graph)
                model_ref = self.get_model()
                if closure_loss.requires_grad:
                    model_ref.backward(closure_loss, optimizer)

                # track metrics for callbacks
                all_callback_metrics.append(callback_metrics)

                # track progress bar metrics
                self.add_tqdm_metrics(progress_bar_metrics)
                all_log_metrics.append(log_metrics)

                # after-backward hook
                if self.is_function_implemented('on_after_backward'):
                    model_ref = self.get_model()
                    model_ref.on_after_backward()

                return closure_loss

            # calculate loss
            loss = optimizer_closure()
            if loss is None:
                continue

            # nan grads
            if self.print_nan_grads:
                self.print_nan_gradients()

            # track total loss for logging; .item() yields a Python number so
            # no reference to the autograd graph is kept (avoid mem leaks)
            self.batch_loss_value += loss.item()

            # gradient update with accumulated gradients
            if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:

                # track gradient norms when requested
                if batch_idx % self.row_log_interval == 0:
                    if self.track_grad_norm > 0:
                        model = self.get_model()
                        grad_norm_dic = model.grad_norm(
                            self.track_grad_norm)

                # clip gradients
                self.clip_gradients()

                # calls .step(), .zero_grad()
                # override function to modify this behavior
                model = self.get_model()
                model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx)

                # calculate running loss for display (mean of the last 100 steps)
                self.running_loss.append(self.batch_loss_value)
                self.batch_loss_value = 0
                self.avg_loss = np.mean(self.running_loss[-100:])

    # activate batch end hook
    if self.is_function_implemented('on_batch_end'):
        model = self.get_model()
        model.on_batch_end()

    # update progress bar
    self.main_progress_bar.update(1)
    self.main_progress_bar.set_postfix(**self.training_tqdm_dict)

    # collapse all metrics into one dict (later optimizers win on key clashes)
    all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}

    # track all metrics for callbacks
    self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})

    return 0, grad_norm_dic, all_log_metrics
def training_forward(self, batch, batch_idx, opt_idx, hiddens):
    """Run one training step for ``batch`` and normalise its output.

    Dispatch depends on the configuration:
      * DP / DDP: ``self.model`` is called directly.
      * single GPU: a shallow copy of the batch is moved to the first entry
        of ``data_parallel_device_ids`` (GPU 0 when that is not a list)
        before ``training_step`` runs.
      * otherwise: ``training_step`` runs on CPU.

    ``opt_idx`` is always passed as the third positional argument. The
    model's ``training_end`` hook may replace the raw output (returning
    ``None`` keeps it). The result is then passed through
    ``self.process_output(..., train=True)``.

    :param batch: current training batch
    :param batch_idx: index of the batch within the epoch
    :param opt_idx: index of the optimizer this step belongs to
    :param hiddens: accepted for interface compatibility; not used here
    :return: whatever ``self.process_output`` returns
    """
    if self.use_ddp or self.use_dp:
        raw_output = self.model(batch, batch_idx, opt_idx)
    else:
        if self.single_gpu:
            device_ids = self.data_parallel_device_ids
            target_gpu = device_ids[0] if isinstance(device_ids, list) else 0
            batch = self.transfer_batch_to_gpu(copy.copy(batch), target_gpu)
        raw_output = self.model.training_step(batch, batch_idx, opt_idx)

    replaced = self.get_model().training_end(raw_output)
    if replaced is not None:
        raw_output = replaced

    return self.process_output(raw_output, train=True)
# --------------- | |
# Utils | |
# --------------- | |
def is_function_implemented(self, f_name):
    """Return True when the model exposes a callable attribute named ``f_name``.

    Missing attributes and non-callable attributes both yield False.
    """
    return callable(getattr(self.get_model(), f_name, None))
def _percent_range_check(self, name): | |
value = getattr(self, name) | |
msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}." | |
if name == "val_check_interval": | |
msg += " If you want to disable validation set `val_percent_check` to 0.0 instead." | |
if not 0. <= value <= 1.: | |
raise ValueError(msg) | |