|
import os |
|
from functools import partial |
|
from typing import Union, List |
|
from pathlib import Path |
|
from datetime import datetime, timedelta |
|
|
|
from omegaconf import DictConfig |
|
from pprint import pprint |
|
import torch |
|
from accelerate.utils import LoggerType |
|
from accelerate import ( |
|
Accelerator, |
|
GradScalerKwargs, |
|
DistributedDataParallelKwargs, |
|
InitProcessGroupKwargs |
|
) |
|
|
|
from ..modules.ema import EMA |
|
from ..utils.logging import get_logger |
|
|
|
|
|
class ModelState:
    """
    Handling logger and `hugging face` accelerate training

    features:
    - Mixed Precision
    - Gradient Scaler
    - Gradient Accumulation
    - Optimizer
    - EMA
    - Logger (default: python print)
    - Monitor (default: wandb, tensorboard)
    """

    def __init__(
            self,
            args,
            log_path_suffix: Union[str, None] = None,
            ignore_log=False,
    ) -> None:
        """Initialize the accelerate state, working directory, trackers and EMA.

        Args:
            args: DictConfig with all training options (mixed_precision,
                batch_size, gradient_accumulate_step, results_path, ...).
            log_path_suffix: optional sub-directory appended to the results path.
            ignore_log: when True, skip the configuration summary printing and
                tracker initialization on the local main process.
        """
        self.args: DictConfig = args

        """check valid"""
        mixed_precision = self.args.get("mixed_precision")
        # A bare boolean flag (e.g. `mixed_precision: false`) means "disabled";
        # normalize it to the string accelerate expects.
        mixed_precision = "no" if isinstance(mixed_precision, bool) else mixed_precision
        split_batches = self.args.get("split_batches", False)
        gradient_accumulate_step = self.args.get("gradient_accumulate_step", 1)
        assert gradient_accumulate_step >= 1, f"except gradient_accumulate_step >= 1, get {gradient_accumulate_step}"

        """create working space"""
        results_folder = self.args.get("results_path", None)
        if results_folder is None:
            self.results_path = Path("./workdir")
        else:
            self.results_path = Path(os.path.join(results_folder, self.args.get("edit_type"), ))

        if log_path_suffix is not None:
            self.results_path = self.results_path / log_path_suffix

        kwargs_handlers = []
        """mixed precision training"""
        # BUG FIX: the scaler handler was previously added only when
        # mixed_precision == "no", where accelerate never creates a GradScaler;
        # the custom scaler settings are only meaningful when mixed precision
        # training is enabled.
        if mixed_precision != "no":
            scaler_handler = GradScalerKwargs(
                init_scale=args.init_scale,
                growth_factor=args.growth_factor,
                backoff_factor=args.backoff_factor,
                growth_interval=args.growth_interval,
                enabled=True
            )
            kwargs_handlers.append(scaler_handler)

        """distributed training"""
        ddp_handler = DistributedDataParallelKwargs(
            dim=0,
            broadcast_buffers=True,
            static_graph=False,
            bucket_cap_mb=25,
            find_unused_parameters=False,
            check_reduction=False,
            gradient_as_bucket_view=False
        )
        kwargs_handlers.append(ddp_handler)

        # allow up to 20 minutes for the process group rendezvous before timing out
        init_handler = InitProcessGroupKwargs(timeout=timedelta(seconds=1200))
        kwargs_handlers.append(init_handler)

        """init visualized tracker"""
        log_with = []
        self.args.visual = False
        if args.use_wandb:
            log_with.append(LoggerType.WANDB)
        if args.tensorboard:
            log_with.append(LoggerType.TENSORBOARD)

        """hugging face Accelerator"""
        self.accelerator = Accelerator(
            device_placement=True,
            split_batches=split_batches,
            mixed_precision=mixed_precision,
            # use the validated local value (defaults to 1 when unset) instead
            # of re-reading args.gradient_accumulate_step
            gradient_accumulation_steps=gradient_accumulate_step,
            cpu=bool(args.use_cpu),
            log_with=None if len(log_with) == 0 else log_with,
            project_dir=self.results_path / "vis",
            kwargs_handlers=kwargs_handlers,
        )

        """logs"""
        if self.accelerator.is_local_main_process:
            self.results_path.mkdir(parents=True, exist_ok=True)
            if not ignore_log:
                print("==> command line args: ")
                print(args.cmd_args)
                print("==> yaml config args: ")
                print(args.yaml_config)

                print("\n***** Model State *****")
                if self.accelerator.distributed_type != "NO":
                    print(f"-> Distributed Type: {self.accelerator.distributed_type}")
                print(f"-> Split Batch Size: {split_batches}, Total Batch Size: {self.actual_batch_size}")
                print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp},"
                      f" Gradient Accumulate Step: {gradient_accumulate_step}")
                print(f"-> Weight dtype: {self.weight_dtype}")

                if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
                    print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")

                if args.use_wandb:
                    print(f"-> Init trackers: 'wandb' ")
                    self.args.visual = True
                    self.__init_tracker(project_name="my_project", tags=None, entity="")

                print(f"-> Working Space: '{self.results_path}'")

        """EMA"""
        self.use_ema = args.get('ema', False)
        self.ema_wrapper = self.__build_ema_wrapper()

        """glob step"""
        self.step = 0

        """log process"""
        self.accelerator.wait_for_everyone()
        print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}')

        self.print("-> state initialization complete \n")

    def __init_tracker(self, project_name, tags, entity):
        """Initialize the accelerate experiment trackers (wandb).

        Args:
            project_name: wandb project name.
            tags: unused here; wandb tags are derived from the batch size.
            entity: wandb entity (team/user) name.
        """
        self.accelerator.init_trackers(
            project_name=project_name,
            config=dict(self.args),
            init_kwargs={
                "wandb": {
                    "notes": "accelerate trainer pipeline",
                    "tags": [
                        f"total batch_size: {self.actual_batch_size}"
                    ],
                    "entity": entity,
                }}
        )

    def __build_ema_wrapper(self):
        """Return a partially-applied EMA constructor, or None when EMA is off."""
        if self.use_ema:
            self.print(f"-> EMA: {self.use_ema}, decay: {self.args.ema_decay}, "
                       f"update_after_step: {self.args.ema_update_after_step}, "
                       f"update_every: {self.args.ema_update_every}")
            ema_wrapper = partial(
                EMA, beta=self.args.ema_decay,
                update_after_step=self.args.ema_update_after_step,
                update_every=self.args.ema_update_every
            )
        else:
            ema_wrapper = None

        return ema_wrapper

    @property
    def device(self):
        """Device this process computes on (delegates to the accelerator)."""
        return self.accelerator.device

    @property
    def weight_dtype(self):
        """Torch dtype matching the active mixed-precision mode."""
        weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
        return weight_dtype

    @property
    def actual_batch_size(self):
        """Effective global batch size per optimizer step."""
        if self.accelerator.split_batches is False:
            actual_batch_size = self.args.batch_size * self.accelerator.num_processes * self.accelerator.gradient_accumulation_steps
        else:
            # BUG FIX: previously asserted on `self.actual_batch_size`, which
            # re-entered this property and recursed forever; the configured
            # batch size is what must divide across processes.
            assert self.args.batch_size % self.accelerator.num_processes == 0
            actual_batch_size = self.args.batch_size
        return actual_batch_size

    @property
    def n_gpus(self):
        """Number of distributed processes (one per device)."""
        return self.accelerator.num_processes

    @property
    def no_decay_params_names(self):
        """Substrings of parameter names that should be excluded from weight decay."""
        no_decay = [
            "bn", "LayerNorm", "GroupNorm",
        ]
        return no_decay

    def no_decay_params(self, model, weight_decay):
        """optimization tricks

        Split `model`'s parameters into two optimizer groups: normalization
        parameters (see `no_decay_params_names`) get weight_decay=0.0, the
        rest get the supplied `weight_decay`.
        """
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in self.no_decay_params_names)
                ],
                "weight_decay": weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in self.no_decay_params_names)
                ],
                "weight_decay": 0.0,
            },
        ]
        return optimizer_grouped_parameters

    def optimized_params(self, model: torch.nn.Module, verbose=True) -> List:
        """return parameters if `requires_grad` is True

        Args:
            model: pytorch models
            verbose: log optimized parameters

        Examples:
            >>> self.params_optimized = self.optimized_params(uvit, verbose=True)
            >>> optimizer = torch.optim.AdamW(self.params_optimized, lr=args.lr)

        Returns:
            a list of parameters
        """
        params_optimized = []
        for key, value in model.named_parameters():
            if value.requires_grad:
                params_optimized.append(value)
                if verbose:
                    self.print("\t {}, {}, {}".format(key, value.numel(), value.shape))
        return params_optimized

    def save_everything(self, fpath: str):
        """Save the model, optimizer, RNG generators, and the GradScaler (main process only)."""
        if not self.accelerator.is_main_process:
            return
        self.accelerator.save_state(fpath)

    def load_save_everything(self, fpath: str):
        """Loading the model, optimizer, RNG generators, and the GradScaler."""
        self.accelerator.load_state(fpath)

    def save(self, milestone: Union[str, float, int], checkpoint: object) -> None:
        """Save `checkpoint` as `model-{milestone}.pt` in the results path (main process only)."""
        if not self.accelerator.is_main_process:
            return

        torch.save(checkpoint, self.results_path / f'model-{milestone}.pt')

    def save_in(self, root: Union[str, Path], checkpoint: object) -> None:
        """Save `checkpoint` at the explicit path `root` (main process only)."""
        if not self.accelerator.is_main_process:
            return

        torch.save(checkpoint, root)

    def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False):
        """Load a state_dict checkpoint into `model` (weights only, no optimizer).

        Args:
            model: the (possibly wrapped) model to load into.
            path: checkpoint file path.
            rm_module_prefix: strip a leading 'module.' from every key
                (checkpoints saved from DataParallel/DDP-wrapped models).

        Returns:
            The unwrapped model with the loaded weights.
        """
        # SECURITY NOTE: torch.load unpickles the file — only load trusted checkpoints.
        ckpt = torch.load(path, map_location=self.accelerator.device)

        unwrapped_model = self.accelerator.unwrap_model(model)
        if rm_module_prefix:
            unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
        else:
            unwrapped_model.load_state_dict(ckpt)
        return unwrapped_model

    def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]):
        """Partially load a checkpoint: only keys that exist in `model` are copied.

        Returns:
            The unwrapped model with the overlapping weights loaded.
        """
        # SECURITY NOTE: torch.load unpickles the file — only load trusted checkpoints.
        ckpt = torch.load(path, map_location=self.accelerator.device)
        self.print(f"pretrained_dict len: {len(ckpt)}")
        unwrapped_model = self.accelerator.unwrap_model(model)
        model_dict = unwrapped_model.state_dict()
        # keep only the checkpoint entries whose keys the model actually has
        pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        unwrapped_model.load_state_dict(model_dict, strict=False)
        self.print(f"selected pretrained_dict: {len(model_dict)}")
        return unwrapped_model

    def print(self, *args, **kwargs):
        """Use in replacement of `print()` to only print once per server."""
        self.accelerator.print(*args, **kwargs)

    def pretty_print(self, msg):
        """Pretty-print a mapping, only on the local main process."""
        if self.accelerator.is_local_main_process:
            pprint(dict(msg))

    def close_tracker(self):
        """End the tracker run (wandb/tensorboard)."""
        self.accelerator.end_training()

    def free_memory(self):
        """Release references held by the accelerator and empty device caches."""
        self.accelerator.clear()

    def close(self, msg: str = "Training complete."):
        """Use in end of training."""
        self.free_memory()

        if torch.cuda.is_available():
            self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
        if self.args.visual:
            self.close_tracker()
        self.print(msg)
|
|