""" |
|
Entry point for dora to launch solvers for running training loops. |
|
See more info on how to use dora: https://github.com/facebookresearch/dora |
|
""" |

import logging
import multiprocessing
import os
from pathlib import Path
import sys
import typing as tp

from dora import git_save, hydra_main, XP
import flashy
import hydra
import omegaconf

from .environment import AudioCraftEnvironment
from .utils.cluster import get_slurm_parameters

logger = logging.getLogger(__name__)


def resolve_config_dset_paths(cfg):
    """Enable Dora to load dataset manifests from a git clone of the repository."""
    for key, value in cfg.datasource.items():
        if isinstance(value, str):
            cfg.datasource[key] = git_save.to_absolute_path(value)


def get_solver(cfg):
    from . import solvers
    # The configured batch size is global; split it evenly across distributed workers,
    # both for the default dataset and for any per-split override.
    assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0
    cfg.dataset.batch_size //= flashy.distrib.world_size()
    for split in ['train', 'valid', 'evaluate', 'generate']:
        if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'):
            assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0
            cfg.dataset[split].batch_size //= flashy.distrib.world_size()
    resolve_config_dset_paths(cfg)
    solver = solvers.get_solver(cfg)
    return solver


def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                       restore: bool = True, load_best: bool = True,
                       ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
    """Given an XP, return the Solver object.

    Args:
        xp (XP): Dora experiment for which to retrieve the solver.
        override_cfg (dict or None): If not None, a dict used to override some values
            in the config of `xp`. This will not impact the XP signature or folder.
            The format differs from the one used in Dora grids: nested keys should be
            nested dicts, not flattened, e.g. `{'optim': {'batch_size': 32}}`.
        restore (bool): If `True` (the default), restore state from the last checkpoint.
        load_best (bool): If `True` (the default), load the best state from the checkpoint.
        ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
        disable_fsdp (bool): If `True`, disables FSDP entirely. This will also
            automatically skip loading the EMA. For solver specific state sources,
            like the optimizer, you might also want to pass `ignore_state_keys=['optimizer']`.
            Must be used with `load_best=True`.
    """
    logger.info(f"Loading solver from XP {xp.sig}. "
                f"Overrides used: {xp.argv}")
    cfg = xp.cfg
    if override_cfg is not None:
        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
    if disable_fsdp and cfg.fsdp.use:
        cfg.fsdp.use = False
        assert load_best is True
        # With FSDP disabled, the FSDP-sharded state sources (model, EMA, best state)
        # cannot be restored directly, so they are ignored and the best state is loaded
        # instead, which is why `load_best` must be True here.
        ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']

    try:
        with xp.enter():
            solver = get_solver(cfg)
            if restore:
                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
        return solver
    finally:
        hydra.core.global_hydra.GlobalHydra.instance().clear()
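
# Illustrative example, following the docstring above: reload an FSDP-trained XP in a
# single-process setup, also skipping the optimizer state as suggested there.
#
#   solver = get_solver_from_xp(xp, disable_fsdp=True, ignore_state_keys=['optimizer'])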


def get_solver_from_sig(sig: str, *args, **kwargs):
    """Return the Solver object from a Dora signature, e.g. to play with it from a notebook.
    See `get_solver_from_xp` for more information.
    """
    xp = main.get_xp_from_sig(sig)
    return get_solver_from_xp(xp, *args, **kwargs)
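
# Illustrative notebook usage; the signature string is a placeholder for a real Dora XP
# signature, and the override follows the nested-dict format documented above.
#
#   solver = get_solver_from_sig('<xp_sig>', override_cfg={'optim': {'batch_size': 32}})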


def init_seed_and_system(cfg):
    import numpy as np
    import torch
    import random
    from audiocraft.modules.transformer import set_efficient_attention_backend

    multiprocessing.set_start_method(cfg.mp_start_method)
    logger.debug('Setting mp start method to %s', cfg.mp_start_method)
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    # torch.manual_seed also seeds CUDA when it is available.
    torch.manual_seed(cfg.seed)
    torch.set_num_threads(cfg.num_threads)
    os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads)
    os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads)
    logger.debug('Setting num threads to %d', cfg.num_threads)
    set_efficient_attention_backend(cfg.efficient_attention_backend)
    logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
    if 'SLURM_JOB_ID' in os.environ:
        tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
        if tmpdir.exists():
            logger.info("Changing tmpdir to %s", tmpdir)
            os.environ['TMPDIR'] = str(tmpdir)


@hydra_main(config_path='../config', config_name='config', version_base='1.1')
def main(cfg):
    init_seed_and_system(cfg)

    # Set up logging both to the XP specific folder and to stderr.
    log_name = '%s.log.{rank}' % cfg.execute_only if cfg.execute_only else 'solver.log.{rank}'
    flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name)
    # Initialize distributed training; nothing extra is needed when launching through Dora.
    flashy.distrib.init()
    solver = get_solver(cfg)
    if cfg.show:
        solver.show()
        return

    if cfg.execute_only:
        assert cfg.execute_inplace or cfg.continue_from is not None, \
            "Please explicitly specify the checkpoint to continue from with continue_from=<sig_or_path> " + \
            "when running with execute_only or set execute_inplace to True."
        solver.restore(replay_metrics=False)  # load the checkpoint before running a single stage
        solver.run_one_stage(cfg.execute_only)
        return

    return solver.run()


# Module-level Dora configuration: output directory and default SLURM parameters.
main.dora.dir = AudioCraftEnvironment.get_dora_dir()
main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm)

if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK):
    print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr)
    main.dora.shared = None


if __name__ == '__main__':
    main()