| import torch
|
| import os
|
|
|
| from accelerate.utils import set_seed
|
| from omegaconf import open_dict, DictConfig
|
|
|
|
|
def check_args_and_env(args: DictConfig) -> None:
    """Validate config invariants and the runtime environment.

    Args:
        args: Hydra/OmegaConf config with ``optim``, ``eval``, ``logging``
            sections and a top-level ``device`` field.

    Raises:
        ValueError: if batch size / logging-eval step settings are inconsistent.
        RuntimeError: if ``device == "gpu"`` but CUDA is unavailable.

    Note: the original used bare ``assert`` statements, which are silently
    stripped under ``python -O``; explicit raises keep validation active.
    """
    # Gradient accumulation must evenly divide the global batch size.
    if args.optim.batch_size % args.optim.grad_acc != 0:
        raise ValueError(
            f"batch_size ({args.optim.batch_size}) must be divisible by "
            f"grad_acc ({args.optim.grad_acc})"
        )

    # Eval steps must align with logging steps so every eval is logged.
    if args.eval.every_steps % args.logging.every_steps != 0:
        raise ValueError(
            f"eval.every_steps ({args.eval.every_steps}) must be divisible by "
            f"logging.every_steps ({args.logging.every_steps})"
        )

    if args.device == "gpu" and not torch.cuda.is_available():
        raise RuntimeError("We use GPU to train/eval the model")
|
|
|
|
|
def opti_flags(args: DictConfig) -> None:
    """Enable TF32 math on Ampere+ GPUs for faster matmul/conv.

    Args:
        args: training config (currently unused; kept for a uniform
            setup-function signature).
    """
    backends = torch.backends
    # TF32 trades a little matmul precision for a large speedup on A100+.
    backends.cuda.matmul.allow_tf32 = True
    backends.cudnn.allow_tf32 = True
|
|
|
|
|
def update_args_with_env_info(args: DictConfig) -> None:
    """Record runtime-environment facts on the config, in place.

    Adds two keys to ``args``:
      * ``slurm_id``   -- value of ``$SLURM_JOB_ID``, or ``"none"`` when the
        job is not running under SLURM.
      * ``working_dir`` -- current working directory at call time.

    Args:
        args: OmegaConf config; ``open_dict`` temporarily lifts struct mode
            so the new keys may be added.
    """
    with open_dict(args):
        # os.getenv's default replaces the explicit None check the
        # original used; behavior is identical.
        args.slurm_id = os.getenv("SLURM_JOB_ID", "none")
        args.working_dir = os.getcwd()
|
|
|
|
|
def setup_args(args: DictConfig) -> None:
    """Run all one-time setup for a training run.

    Validates the config, enriches it with environment info, enables
    CUDA performance flags, and seeds the RNGs when a seed is given.

    Args:
        args: full training config; mutated in place by the env-info step.
    """
    # Fail fast on an inconsistent config before doing anything else.
    check_args_and_env(args)
    update_args_with_env_info(args)
    opti_flags(args)

    seed = args.seed
    if seed is not None:
        # accelerate's set_seed handles global RNG seeding.
        set_seed(seed)
|
|
|