from typing import List, Optional, Sequence
from pathlib import Path

import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase

from src.utils import utils

log = utils.get_logger(__name__)


def last_modification_time(path):
    """Return the latest modification time of `path`.

    For a directory, consider the files and directories one level below it and
    return the most recent mtime among them. Returns None if the path does not exist.
    """
    path = Path(path)
    if path.is_file():
        return path.stat().st_mtime
    elif path.is_dir():
        return max(child.stat().st_mtime for child in path.iterdir())
    else:
        return None
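
# Illustrative behavior of last_modification_time (paths are made up for the example):
#   last_modification_time('checkpoints/last.ckpt')  -> mtime of that file
#   last_modification_time('checkpoints/step1000/')  -> newest mtime among its direct children
#                                                       (e.g. a DeepSpeed checkpoint directory)
#   last_modification_time('missing.ckpt')            -> None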


def train(config: DictConfig) -> Optional[float]:
    """Contains the training pipeline.

    Instantiates all PyTorch Lightning objects from the Hydra config.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization.
    """
    # Set seed for random number generators in pytorch, numpy and python.random
    if config.get("seed"):
        seed_everything(config.seed, workers=True)

    # We want to add fields to config, so we need to call OmegaConf.set_struct
    OmegaConf.set_struct(config, False)

    # Init lightning model
    model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, _recursive_=False)
    datamodule: LightningDataModule = model._datamodule
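
    # Note: `config.task` is expected to resolve (via `_target_`) to a LightningModule that
    # builds its own LightningDataModule and exposes it as `_datamodule`; this file does not
    # instantiate a separate datamodule config.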

    # Init lightning callbacks
    callbacks: List[Callback] = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            if cb_conf is not None and "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # Init lightning loggers
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if lg_conf is not None and "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))
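
    # For reference, the callbacks / logger sections of the Hydra config are expected to
    # look roughly like this (illustrative YAML sketch, not copied from the repo's configs):
    #
    #   callbacks:
    #     model_checkpoint:
    #       _target_: pytorch_lightning.callbacks.ModelCheckpoint
    #       dirpath: checkpoints/
    #       save_last: True
    #   logger:
    #     wandb:
    #       _target_: pytorch_lightning.loggers.WandbLogger
    #
    # Entries without a `_target_` key are skipped by the loops above.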

    # Resume from the most recent checkpoint if requested: prefer Lightning's
    # auto-save checkpoint when it is newer than last.ckpt.
    ckpt_cfg = {}
    if config.get('resume'):
        try:
            checkpoint_path = Path(config.callbacks.model_checkpoint.dirpath)
            if checkpoint_path.is_dir():
                last_ckpt = checkpoint_path / 'last.ckpt'
                autosave_ckpt = checkpoint_path / '.pl_auto_save.ckpt'
                if not (last_ckpt.exists() or autosave_ckpt.exists()):
                    raise FileNotFoundError("Resume requires either last.ckpt or .pl_auto_save.ckpt")
                if ((not last_ckpt.exists())
                    or (autosave_ckpt.exists()
                        and last_modification_time(autosave_ckpt) > last_modification_time(last_ckpt))):
                    # autosave_ckpt = autosave_ckpt.replace(autosave_ckpt.with_name('.pl_auto_save_loaded.ckpt'))
                    checkpoint_path = autosave_ckpt
                else:
                    checkpoint_path = last_ckpt
            # DeepSpeed's checkpoint is a directory, not a file
            if checkpoint_path.is_file() or checkpoint_path.is_dir():
                ckpt_cfg = {'ckpt_path': str(checkpoint_path)}
            else:
                log.info(f'Checkpoint file {str(checkpoint_path)} not found. Will start training from scratch')
        except (KeyError, FileNotFoundError):
            pass

    # Configure ddp automatically
    n_devices = config.trainer.get('devices', 1)
    if isinstance(n_devices, Sequence):  # trainer.devices could be [1, 3] for example
        n_devices = len(n_devices)
    if n_devices > 1 and config.trainer.get('strategy', None) is None:
        config.trainer.strategy = dict(
            _target_='pytorch_lightning.strategies.DDPStrategy',
            find_unused_parameters=False,
            gradient_as_bucket_view=True,  # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
        )
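
    # An explicit `trainer.strategy` in the config (or on the command line) takes precedence
    # over this default; e.g. a Hydra override such as `trainer.strategy=ddp` would skip the
    # block above (entry-point script name below is illustrative, not from this file):
    #   python run.py trainer.devices=4 trainer.strategy=ddp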

    # Init lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger)

    # Train the model
    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule, **ckpt_cfg)

    # Evaluate model on test set, using the best model achieved during training
    if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    # Print path to best checkpoint
    if not config.trainer.get("fast_dev_run"):
        log.info(f"Best model ckpt: {trainer.checkpoint_callback.best_model_path}")

    # Return metric score for hyperparameter optimization
    optimized_metric = config.get("optimized_metric")
    if optimized_metric:
        return trainer.callback_metrics[optimized_metric]
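
    # `optimized_metric` must match a key present in `trainer.callback_metrics`, i.e. a metric
    # name logged during fit (e.g. "val/loss"; the exact key depends on what the task logs).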