Spaces:
Sleeping
Sleeping
""" | |
Misc functions, including distributed helpers, mostly from torchvision | |
""" | |
import glob | |
import math | |
import os | |
import time | |
import datetime | |
import random | |
from dataclasses import dataclass | |
from collections import defaultdict, deque | |
from typing import Callable, Optional | |
from PIL import Image | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
import torchvision | |
from accelerate import Accelerator | |
from omegaconf import DictConfig | |
from configs.structured import ProjectConfig | |
@dataclass
class TrainState:
    """Training-progress counters returned by ``resume_from_checkpoint``.

    Declared as a dataclass so it can be constructed with keyword arguments,
    e.g. ``TrainState(epoch=epoch, step=step, best_val=best_val)``.
    """
    epoch: int = 0  # epoch index to (re)start from; resume sets checkpoint['epoch'] + 1
    step: int = 0  # step count, restored from checkpoint['step'] when resuming
    best_val: Optional[float] = None  # checkpoint['best_val'] when resuming; None for a fresh run
def get_optimizer(cfg: ProjectConfig, model: torch.nn.Module, accelerator: Accelerator) -> torch.optim.Optimizer:
    """Build the optimizer described by ``cfg.optimizer`` for ``model``.

    The learning rate is ``cfg.optimizer.lr`` unchanged, unless
    ``cfg.optimizer.scale_learning_rate_with_batch_size`` is set. In that case
    it is multiplied by the number of processes and the per-process batch size.

    Trainable parameters are split into two groups. Those whose name contains
    ``"bias"`` or ``"LayerNorm.weight"`` get ``weight_decay=0.0``; all others
    get ``cfg.optimizer.weight_decay``.

    Args:
        cfg: project config. Reads ``optimizer.{type, name, lr, weight_decay,
            kwargs, scale_learning_rate_with_batch_size}`` and
            ``dataloader.batch_size``.
        model: module whose ``named_parameters()`` are optimized. Parameters
            with ``requires_grad=False`` are skipped.
        accelerator: only ``accelerator.state.num_processes`` is read, when
            scaling the learning rate.

    Returns:
        The constructed optimizer.

    Raises:
        NotImplementedError: if ``cfg.optimizer.type`` is not ``'torch'``,
            ``'timm'`` or ``'transformers'``.
    """
    # Determine the learning rate
    if cfg.optimizer.scale_learning_rate_with_batch_size:
        # Linear scaling: base lr * number of processes * per-process batch size
        lr = accelerator.state.num_processes * cfg.dataloader.batch_size * cfg.optimizer.lr
        print('lr = {ws} (num gpus) * {bs} (batch_size) * {blr} (base learning rate) = {lr}'.format(
            ws=accelerator.state.num_processes, bs=cfg.dataloader.batch_size, blr=cfg.optimizer.lr, lr=lr))
    else:  # use the configured learning rate as-is (no batch-size scaling)
        lr = cfg.optimizer.lr
        print('lr = {lr} (absolute learning rate)'.format(lr=lr))

    # Split trainable parameters into decay / no-decay groups in a single pass.
    # Within each group, named_parameters() order is preserved.
    no_decay = ["bias", "LayerNorm.weight"]
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    parameters = [
        {"params": decay_params, "weight_decay": cfg.optimizer.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    # Construct optimizer
    if cfg.optimizer.type == 'torch':
        optimizer_cls = getattr(torch.optim, cfg.optimizer.name)  # an optimizer *class*
        optimizer = optimizer_cls(parameters, lr=lr, **cfg.optimizer.kwargs)
    elif cfg.optimizer.type == 'timm':
        from timm.optim import create_optimizer_v2
        optimizer = create_optimizer_v2(model_or_params=parameters, lr=lr, **cfg.optimizer.kwargs)
    elif cfg.optimizer.type == 'transformers':
        import transformers
        optimizer_cls = getattr(transformers, cfg.optimizer.name)  # an optimizer *class*
        optimizer = optimizer_cls(parameters, lr=lr, **cfg.optimizer.kwargs)
    else:
        raise NotImplementedError(f'Invalid optimizer configs: {cfg.optimizer}')
    return optimizer
def get_scheduler(cfg: ProjectConfig, optimizer: torch.optim.Optimizer) -> Callable:
    """Build the learning-rate scheduler described by ``cfg.scheduler``.

    Supported ``cfg.scheduler.type`` values:

    * ``'torch'``: builds ``torch.optim.lr_scheduler.<cfg.scheduler.name>``
      with ``cfg.scheduler.kwargs``. When ``cfg.scheduler.warmup`` is non-zero,
      the result is wrapped in ``warmup_scheduler.GradualWarmupScheduler``
      with that many warmup epochs.
    * ``'timm'``: calls ``timm.scheduler.create_scheduler`` with
      ``cfg.scheduler.kwargs`` passed as ``args``.
    * ``'transformers'``: calls ``transformers.get_scheduler`` with
      ``cfg.scheduler.kwargs``.

    Args:
        cfg: project config; reads ``scheduler.{type, name, kwargs, warmup}``.
        optimizer: optimizer whose learning rate is scheduled.

    Returns:
        The scheduler object.

    Raises:
        ValueError: ``type == 'torch'`` but no ``cfg.scheduler.name`` is set.
        NotImplementedError: unknown ``cfg.scheduler.type``.
    """
    if cfg.scheduler.type == 'torch':
        # Look the class up by name, as get_optimizer does with cfg.optimizer.name.
        # Looking it up by cfg.scheduler.type would always use the literal
        # 'torch' here, which is not a scheduler class.
        scheduler_name = cfg.scheduler.get('name', None)
        if not scheduler_name:
            raise ValueError(f"scheduler.type='torch' requires scheduler.name; got {cfg.scheduler}")
        scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
        scheduler = scheduler_cls(optimizer=optimizer, **cfg.scheduler.kwargs)
        if cfg.scheduler.get('warmup', 0):
            from warmup_scheduler import GradualWarmupScheduler
            scheduler = GradualWarmupScheduler(optimizer, multiplier=1,
                                               total_epoch=cfg.scheduler.warmup, after_scheduler=scheduler)
    elif cfg.scheduler.type == 'timm':
        from timm.scheduler import create_scheduler
        scheduler, _ = create_scheduler(optimizer=optimizer, args=cfg.scheduler.kwargs)
    elif cfg.scheduler.type == 'transformers':
        # Aliased so it does not shadow this function's own name.
        # Default: linear scheduler without warm up and linear decay.
        from transformers import get_scheduler as hf_get_scheduler
        scheduler = hf_get_scheduler(optimizer=optimizer, **cfg.scheduler.kwargs)
    else:
        raise NotImplementedError(f'invalid scheduler configs: {cfg.scheduler}')
    return scheduler
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages, one per entry of ``topk``.

    Args:
        output: scores indexed as ``output[sample, class]``. The top-k is
            taken along dim 1.
        target: true class index per sample (``target.size(0)`` samples).
        topk: tuple of k values to evaluate.

    Returns:
        A list of one-element float tensors, each holding
        ``100 * (#samples whose target is in the top k) / batch_size``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)
    # Transposed to (largest_k, n_samples), so row j holds each sample's j-th best class.
    top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = top_indices.eq(target.reshape(1, -1).expand_as(top_indices))
    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    The last ``window_size`` raw values are kept in a deque for the windowed
    statistics (``median``, ``avg``, ``max``, ``value``). ``total`` and
    ``count`` accumulate over every update and back ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``. ``n`` weights it in ``total``/``count`` only;
        the window deque receives ``value`` once regardless of ``n``."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self, device='cuda'):
        """
        All-reduce ``count`` and ``total`` across processes (no-op when
        torch.distributed is not initialized).

        Warning: does not synchronize the deque!
        """
        if not using_distributed():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    # The statistics are properties: __str__ formats them with e.g. ':.4f',
    # and callers use them as numbers (e.g. ``iter_time.global_avg * n``).
    @property
    def median(self):
        """Median of the values in the current window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values in the current window."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update (0.0 before the first update)."""
        return self.total / max(self.count, 1)

    @property
    def max(self):
        """Largest value in the current window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        # Empty string until at least one value has been recorded.
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value) if len(self.deque) > 0 else ""
class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic progress printing.

    Meters are created on first use by ``update`` and can be read back as
    attributes (``logger.loss``) through ``__getattr__``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword in the meter of the same name.

        The reserved keyword ``n`` (default 1) is forwarded as the weight to
        every meter. Tensor values are converted with ``.item()``.
        """
        n = kwargs.pop('n', 1)
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v, n=n)

    def __getattr__(self, attr):
        # Read ``meters`` via __dict__: when it is not set yet (e.g. an instance
        # created with __new__ during copy/unpickling), ``self.meters`` would
        # re-enter __getattr__ and recurse until RecursionError.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self, device='cuda'):
        """Synchronize every meter's global totals across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes(device=device)

    def add_meter(self, name, meter):
        """Register (or replace) a meter under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items of ``iterable``, printing progress periodically.

        A status line (ETA, meters, iteration/data time, and max CUDA memory
        when CUDA is available) is printed when ``i % print_freq == 0`` and on
        the last item. A total-time line is printed at the end. ``iterable``
        must support ``len()``.
        """
        i = 0
        if not header:
            header = ''
        num_iters = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item (e.g. data loading).
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time, including the consumer's work on ``obj``.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids ZeroDivisionError for an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_iters, 1)))
class NormalizeInverse(torchvision.transforms.Normalize):
    """Approximately undo a ``Normalize(mean, std)`` transform.

    Built as another ``Normalize`` whose parameters are ``std' = 1 / (std + 1e-7)``
    and ``mean' = -mean * std'``. Applying it computes
    ``(x - mean') / std' = x * (std + 1e-7) + mean``; the 1e-7 guards against
    a zero ``std``.
    """

    def __init__(self, mean, std):
        inv_std = 1 / (torch.as_tensor(std) + 1e-7)
        inv_mean = -torch.as_tensor(mean) * inv_std
        super().__init__(mean=inv_mean, std=inv_std)

    def __call__(self, tensor):
        # Defensive copy so the caller's tensor is never modified.
        cloned = tensor.clone()
        return super().__call__(cloned)
def resume_from_checkpoint(cfg: ProjectConfig, model, optimizer=None, scheduler=None, model_ema=None):
    """Restore model weights and, optionally, EMA, optimizer, scheduler and train state.

    Steps:
      1. If ``cfg.checkpoint.resume`` is falsy, return a fresh ``TrainState``.
      2. If it is not an existing file, look for
         ``<cfg.run.code_dir_abs>/outputs/<cfg.run.name>/single/checkpoint-latest.pth``.
         Failing that, look for ``checkpoint-latest.pth`` inside the (at most
         one) ``outputs/<cfg.run.name>/2023-*`` folder. If found,
         ``cfg.checkpoint.resume`` is overwritten with that path; otherwise a
         fresh ``TrainState`` is returned.
      3. Load the checkpoint on CPU. Load ``checkpoint['model']`` (or the whole
         checkpoint when that key is absent) into ``model`` with
         ``strict=False``, stripping a leading ``"module."`` from keys if any
         key has it.
      4. When ``cfg.ema.use_ema``, load ``checkpoint['model_ema']`` into
         ``model_ema`` if it is present and non-empty.
      5. When ``'train' in cfg.run.job`` and ``cfg.checkpoint.resume_training``,
         restore the optimizer, scheduler and/or epoch/step/best_val according
         to the ``cfg.checkpoint.resume_training_*`` flags.

    Args:
        cfg: project config (``checkpoint``, ``run``, ``ema`` sections are read;
            ``cfg.checkpoint.resume`` may be rewritten).
        model: module receiving the weights.
        optimizer, scheduler: objects with ``load_state_dict``; required only
            when the corresponding ``resume_training_*`` flag is set.
        model_ema: EMA wrapper; required only when ``cfg.ema.use_ema``.

    Returns:
        TrainState: when train state is resumed, ``epoch`` is the checkpoint's
        epoch + 1; otherwise defaults.

    Raises:
        AssertionError: more than one dated fallback folder exists; a requested
            optimizer/scheduler entry is missing and ``'tune'`` is not in
            ``cfg.run.name``; or ``epoch``/``step``/``best_val`` are missing when
            resuming train state.
    """
    # Check if resuming training from a checkpoint
    if not cfg.checkpoint.resume:
        print('Starting training from scratch')
        return TrainState()

    # XH: find checkpoint path automatically
    if not os.path.isfile(cfg.checkpoint.resume):
        print(f"The given checkpoint path {cfg.checkpoint.resume} does not exist, trying to find one...")
        ckpt_file = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.name}/single/checkpoint-latest.pth')
        if not os.path.isfile(ckpt_file):
            # Backward compatibility: fall back to a dated run folder, but refuse
            # to guess when there is more than one.
            folders = sorted(glob.glob(os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.name}/2023-*')))
            assert len(folders) <= 1
            if len(folders) > 0:
                ckpt_file = os.path.join(folders[0], 'checkpoint-latest.pth')
        if os.path.isfile(ckpt_file):
            print(f"Found checkpoint at {ckpt_file}!")
            cfg.checkpoint.resume = ckpt_file
        else:
            print(f"No checkpoint found in outputs/{cfg.run.name}/single/!")
            return TrainState()

    # If resuming, load model state dict
    print(f'Loading checkpoint ({datetime.datetime.now()})')
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(cfg.checkpoint.resume, map_location='cpu')
    if 'model' in checkpoint:
        state_dict, key = checkpoint['model'], 'model'
    else:
        print("Warning: no model found in checkpoint!")
        state_dict, key = checkpoint, 'N/A'
    prefix = 'module.'
    if any(k.startswith(prefix) for k in state_dict.keys()):
        # Strip only the leading wrapper prefix. A global str.replace would also
        # mangle keys that contain 'module.' further inside the name.
        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
        print('Removed "module." from checkpoint state dict')
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print(f'Loaded model checkpoint key {key} from {cfg.checkpoint.resume}')
    if len(missing_keys):
        print(f' - Missing_keys: {missing_keys}')
    if len(unexpected_keys):
        print(f' - Unexpected_keys: {unexpected_keys}')
    print(f"{len(missing_keys)} missing, {len(unexpected_keys)} unexpected! total {len(model.state_dict().keys())} modules.")
    if 'step' in checkpoint:
        print("Number of trained steps:", checkpoint['step'])

    # TODO: implement better loading for fine tuning
    # Resume model ema
    if cfg.ema.use_ema:
        # .get(): a checkpoint without a 'model_ema' entry must take the fallback
        # branch below instead of raising KeyError.
        if checkpoint.get('model_ema'):
            model_ema.load_state_dict(checkpoint['model_ema'])
            print('Loaded model ema from checkpoint')
        else:
            # NOTE(review): passes a parameter iterator to load_state_dict; confirm
            # the EMA wrapper type accepts this.
            model_ema.load_state_dict(model.parameters())
            print('No model ema in checkpoint; loaded current parameters into model')
    else:
        if 'model_ema' in checkpoint and checkpoint['model_ema']:
            print('Not using model ema, but model_ema found in checkpoint (you probably want to resume it!)')
        else:
            print('Not using model ema, and no model_ema found in checkpoint.')

    # Resume optimizer and/or training state
    train_state = TrainState()
    if 'train' in cfg.run.job:
        if cfg.checkpoint.resume_training:
            # NOTE(review): always true here, because the last disjunct is the
            # condition just tested.
            assert (
                cfg.checkpoint.resume_training_optimizer
                or cfg.checkpoint.resume_training_scheduler
                or cfg.checkpoint.resume_training_state
                or cfg.checkpoint.resume_training
            ), f'Invalid configs: {cfg.checkpoint}'

            if cfg.checkpoint.resume_training_optimizer:
                if 'optimizer' not in checkpoint:
                    assert 'tune' in cfg.run.name, f'please check the checkpoint for run {cfg.run.name}'
                    print("Warning: not loading optimizer!")
                else:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                    print('Loaded optimizer from checkpoint')
            else:
                print('Did not load optimizer from checkpoint')

            if cfg.checkpoint.resume_training_scheduler:
                if 'scheduler' not in checkpoint:
                    assert 'tune' in cfg.run.name, f'please check the checkpoint for run {cfg.run.name}'
                    print("Warning: not loading scheduler!")
                else:
                    scheduler.load_state_dict(checkpoint['scheduler'])
                    print('Loaded scheduler from checkpoint')
            else:
                print('Did not load scheduler from checkpoint')

            if cfg.checkpoint.resume_training_state:
                if 'steps' in checkpoint and 'step' not in checkpoint:  # fixes an old typo
                    checkpoint['step'] = checkpoint.pop('steps')
                assert {'epoch', 'step', 'best_val'}.issubset(set(checkpoint.keys()))
                # Resume from the epoch after the saved one.
                epoch, step, best_val = checkpoint['epoch'] + 1, checkpoint['step'], checkpoint['best_val']
                train_state = TrainState(epoch=epoch, step=step, best_val=best_val)
                print(f'Resumed state from checkpoint: step {step}, epoch {epoch}, best_val {best_val}')
            else:
                print('Did not load train state from checkpoint')
        else:
            print('Did not resume optimizer, scheduler, or epoch from checkpoint')

    print(f'Finished loading checkpoint ({datetime.datetime.now()})')
    return train_state
def setup_distributed_print(is_master):
    """Replace ``builtins.print`` so only the master process prints.

    The replacement prints through ``rich.print``. Any process may still print
    by passing ``force=True``; the keyword is removed before forwarding.
    """
    import builtins
    from rich import print as rich_print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        rich_print(*args, **kwargs)

    builtins.print = print
def using_distributed():
    """Return whether torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    if not using_distributed():
        return 0
    return dist.get_rank()
def set_seed(seed):
    """Seed torch (CPU and CUDA), NumPy and ``random`` with ``seed + rank``.

    Offsetting by the process rank gives each process a different seed. Also
    sets ``torch.backends.cudnn.enabled`` and ``torch.backends.cudnn.benchmark``
    to True.

    NOTE(review): in the distributed case this calls ``print(..., force=True)``,
    which presumably relies on ``setup_distributed_print`` having replaced
    the builtin ``print`` (the builtin rejects ``force``). Confirm call order.
    """
    rank = get_rank()
    node_seed = seed + rank
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seeder(node_seed)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    message = f'Seeding node {rank} with seed {node_seed}'
    if using_distributed():
        print(message, force=True)
    else:
        print(message)
def compute_grad_norm(parameters):
    """Return the global L2 norm of the gradients of ``parameters`` as a Python float.

    Parameters without a gradient, or with ``requires_grad=False``, are skipped.
    The squared per-parameter norms are summed in Python and the square root of
    that sum is returned (0.0 when nothing contributes).
    """
    squared_sum = 0
    for param in parameters:
        if param.grad is None or not param.requires_grad:
            continue
        squared_sum += param.grad.detach().data.norm(2).item() ** 2
    return squared_sum ** 0.5
class dotdict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    Reading a missing attribute returns ``None`` (``dict.get`` semantics)
    rather than raising. Deleting a missing attribute raises ``KeyError``.
    """

    def __getattr__(self, key):
        return self.get(key)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]