# styledrop/utils.py
# Author: zideliu; commit 28c6826 ("StyleDrop init")
import pickle
import torch
import torch.nn as nn
import numpy as np
import os
from tqdm import tqdm
from torchvision.utils import save_image
from torch import distributed as dist
from loguru import logger
# Module-wide alias: code in this file calls `logging.info(...)`, which routes
# to loguru's `logger`, not to the standard-library `logging` module.
logging = logger
def set_logger(log_level='info', fname=None):
    """Configure the module logger (loguru) level, format and optional log file.

    `logging` in this module is loguru's `logger`. The previous implementation
    called absl-logging APIs on it (`get_absl_handler`, `set_verbosity`,
    `get_absl_logger`), none of which loguru provides, so every call raised
    AttributeError. This version uses loguru's own sink API instead.

    Args:
        log_level: level name, case-insensitive (e.g. 'info', 'debug').
        fname: optional path of a file that additionally receives the log.
    """
    import sys

    # Mirrors the old '%(asctime)s - %(filename)s - %(message)s' layout.
    fmt = '{time:YYYY-MM-DD HH:mm:ss,SSS} - {file} - {message}'
    level = log_level.upper() if isinstance(log_level, str) else log_level
    # Drop previously installed sinks (including loguru's default stderr sink)
    # so repeated calls do not duplicate output.
    logging.remove()
    logging.add(sys.stderr, level=level, format=fmt)
    if fname is not None:
        logging.add(fname, level=level, format=fmt)
def dct2str(dct):
    """Return str() of a copy of `dct` whose values are formatted with '.6g'."""
    formatted = dict((key, format(val, '.6g')) for key, val in dct.items())
    return str(formatted)
def get_nnet(name, **kwargs):
    """Instantiate the UViT variant registered under `name`.

    Supported names: 'uvit_t2i_vq' (libs.uvit_t2i_vq) and 'uvit_vq'
    (libs.uvit_vq). Keyword arguments are forwarded to the constructor.

    Raises:
        NotImplementedError: for any other `name`.
    """
    if name == 'uvit_vq':
        from libs.uvit_vq import UViT
        return UViT(**kwargs)
    if name == 'uvit_t2i_vq':
        from libs.uvit_t2i_vq import UViT
        return UViT(**kwargs)
    raise NotImplementedError(name)
def set_seed(seed: int):
    """Seed torch's and NumPy's global RNGs; do nothing when `seed` is None."""
    if seed is None:
        return
    torch.manual_seed(seed)
    np.random.seed(seed)
def get_optimizer(params, name, **kwargs):
    """Build a torch optimizer over `params` by name ('adam' or 'adamw').

    Extra keyword arguments (lr, weight_decay, ...) go to the constructor.

    Raises:
        NotImplementedError: for an unsupported `name`.
    """
    if name not in ('adam', 'adamw'):
        raise NotImplementedError(name)
    if name == 'adamw':
        from torch.optim import AdamW
        return AdamW(params, **kwargs)
    from torch.optim import Adam
    return Adam(params, **kwargs)
def customized_lr_scheduler(optimizer, warmup_steps=-1):
    """Return a LambdaLR giving linear warmup over `warmup_steps` steps.

    The LR multiplier is min(step / warmup_steps, 1) when `warmup_steps` is
    positive, and a constant 1 otherwise.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def warmup_factor(step):
        return min(step / warmup_steps, 1) if warmup_steps > 0 else 1

    return LambdaLR(optimizer, warmup_factor)
def get_lr_scheduler(optimizer, name, **kwargs):
    """Build an LR scheduler by name; only 'customized' is supported.

    Raises:
        NotImplementedError: for any other `name`.
    """
    if name != 'customized':
        raise NotImplementedError(name)
    return customized_lr_scheduler(optimizer, **kwargs)
def ema(model_dest: nn.Module, model_src: nn.Module, rate):
    """Update `model_dest`'s parameters in place from `model_src`'s.

    Parameters are matched by name. Those whose name contains 'adapter' are
    replaced by a detached copy of the source. All others become
    rate * dest + (1 - rate) * src.
    """
    src_by_name = {name: param for name, param in model_src.named_parameters()}
    for name, dest in model_dest.named_parameters():
        src = src_by_name[name]
        assert src is not dest
        if 'adapter' in name:
            dest.data = src.detach().clone()
            continue
        dest.data.mul_(rate).add_((1 - rate) * src.data)
class TrainState(object):
    """Resumable training state: optimizer, LR scheduler, step, net and its EMA copy.

    Provides EMA updates, checkpoint save/load/resume, adapter loading,
    device moves and adapter-only freezing.
    """

    def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.step = step
        self.nnet = nnet
        self.nnet_ema = nnet_ema

    def ema_update(self, rate=0.9999):
        """Blend `nnet` into `nnet_ema` with the module-level `ema` helper (no-op without an EMA net)."""
        if self.nnet_ema is not None:
            ema(self.nnet_ema, self.nnet, rate)

    def save(self, path, adapter_only=False, name=""):
        """Write the state into directory `path`, creating it if needed.

        `step.pth` is always written. With `adapter_only`, only
        `self.nnet.adapter`'s state dict is saved, as `<name>adapter.pth`.
        Otherwise every non-None attribute except `step` is saved as
        `<attribute>.pth`.
        """
        os.makedirs(path, exist_ok=True)
        torch.save(self.step, os.path.join(path, 'step.pth'))
        if adapter_only:
            torch.save(self.nnet.adapter.state_dict(), os.path.join(path, name + 'adapter.pth'))
        else:
            for key, val in self.__dict__.items():
                if key != 'step' and val is not None:
                    torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))

    def make_dict(self, model, state_dict):
        """Return a complete state dict for `model`, preferring entries of `state_dict`.

        Keys of `model.state_dict()` found in `state_dict` take the loaded
        tensor; the rest keep the model's current tensor. Extra keys in
        `state_dict` are ignored, and every value is cloned.
        """
        # Build the model's state dict once. The old loop rebuilt it for every
        # key missing from `state_dict`.
        current = model.state_dict()
        return {k: (state_dict[k] if k in state_dict else v).clone() for k, v in current.items()}

    def load(self, path):
        """Restore `step` and model weights from checkpoint directory `path`.

        The optimizer and LR scheduler are not restored. `nnet`/`nnet_ema`
        are loaded through `make_dict`, so keys missing from the checkpoint
        keep their current values.
        """
        logging.info(f'load from {path}')
        self.step = torch.load(os.path.join(path, 'step.pth'), map_location='cpu')
        for key, val in self.__dict__.items():
            if val is None or key in ('step', 'optimizer', 'lr_scheduler'):
                continue
            # NOTE(review): torch.load unpickles; only load trusted checkpoints.
            loaded = torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu')
            if key in ('nnet', 'nnet_ema'):
                loaded = self.make_dict(val, loaded)
            val.load_state_dict(loaded)

    def load_adapter(self, path):
        """Load an adapter state dict from file `path` into `nnet.adapter` and `nnet_ema.adapter`.

        Networks that are None are skipped with a log message.
        """
        logging.info('load adapter from {}'.format(path))
        adapter = torch.load(path, map_location='cpu')
        for key in ('nnet', 'nnet_ema'):
            # Both attributes always exist (possibly as None), so the previous
            # `key in self.__dict__` test could never skip a missing network.
            model = self.__dict__.get(key)
            if model is not None:
                model.adapter.load_state_dict(adapter)
            else:
                logging.info('adapter not in state_dict')

    def resume(self, ckpt_root, adapter_path=None, step=None):
        """Load the checkpoint at `ckpt_root` (optionally also an adapter file).

        `ckpt_root` is either a `.ckpt` checkpoint directory itself or a folder
        of `<step>.ckpt` entries. In the latter case `step` picks one; by
        default the largest step is used. Returns silently when `ckpt_root`
        does not exist or contains no `<step>.ckpt` entries.
        """
        if not os.path.exists(ckpt_root):
            return
        if ckpt_root.endswith('.ckpt'):
            ckpt_path = ckpt_root
        else:
            if step is None:
                # Only accept entries named exactly `<digits>.ckpt`; stray names
                # such as `notes.ckpt.bak` used to crash int() here.
                steps = []
                for entry in os.listdir(ckpt_root):
                    stem = entry[:-len('.ckpt')]
                    if entry.endswith('.ckpt') and stem.isdecimal():
                        steps.append(int(stem))
                if not steps:
                    return
                step = max(steps)
            ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
        logging.info(f'resume from {ckpt_path}')
        self.load(ckpt_path)
        if adapter_path is not None:
            self.load_adapter(adapter_path)

    def to(self, device):
        """Move every nn.Module attribute (nnet, nnet_ema) to `device` in place."""
        for key, val in self.__dict__.items():
            if isinstance(val, nn.Module):
                val.to(device)

    def freeze(self):
        """Disable gradients on `nnet` except parameters whose name contains 'adapter'."""
        self.nnet.requires_grad_(False)
        for name, p in self.nnet.named_parameters():
            if 'adapter' in name:
                p.requires_grad_(True)
def cnt_params(model):
    """Return the total number of scalar elements across `model.parameters()`."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def initialize_train_state(config, device):
    """Build networks, optimizer and scheduler from `config` and wrap them in a TrainState.

    Two networks are constructed from `config.nnet`. Only `nnet.adapter`'s
    parameters are given to the optimizer. `nnet_ema` is put in eval mode,
    and its parameters are set from `nnet` via `ema_update(0)`. The state is
    moved to `device` before being returned.
    """
    nnet = get_nnet(**config.nnet)
    trainable = list(nnet.adapter.parameters())
    nnet_ema = get_nnet(**config.nnet)
    nnet_ema.eval()
    logging.info(f'nnet has {cnt_params(nnet)} parameters')
    optimizer = get_optimizer(trainable, **config.optimizer)
    train_state = TrainState(
        optimizer=optimizer,
        lr_scheduler=get_lr_scheduler(optimizer, **config.lr_scheduler),
        step=0,
        nnet=nnet,
        nnet_ema=nnet_ema,
    )
    train_state.ema_update(0)
    train_state.to(device)
    return train_state
def amortize(n_samples, batch_size):
    """Split `n_samples` into batches of `batch_size`, plus a final remainder batch if nonzero."""
    n_full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * n_full
    if remainder:
        sizes.append(remainder)
    return sizes
def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, dist=True):
    """Generate `n_samples` samples with `sample_fn` and save them as `<idx>.png` in `path`.

    Args:
        accelerator: object providing `num_processes`, `is_main_process` and
            `gather` (presumably a HuggingFace Accelerate `Accelerator`; confirm).
        path: output directory; created when non-empty.
        n_samples: total number of images to write.
        mini_batch_size: batch size passed to `sample_fn` on each call.
        sample_fn: callable taking a batch size and returning a batch of samples.
        unpreprocess_fn: optional post-processing applied to each batch before
            saving; identity when None.
        dist: when True, batches are gathered across processes, so each step
            covers `mini_batch_size * accelerator.num_processes` samples.
            (This parameter shadows the module-level `dist` import inside here.)

    Only the main process writes files.
    """
    if unpreprocess_fn is None:
        # The None default used to be called unconditionally (TypeError).
        def unpreprocess_fn(x):
            return x

    if path:
        os.makedirs(path, exist_ok=True)
    idx = 0
    batch_size = mini_batch_size * accelerator.num_processes if dist else mini_batch_size
    for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
        samples = unpreprocess_fn(sample_fn(mini_batch_size))
        if dist:
            samples = accelerator.gather(samples.contiguous())
        # Truncate in both modes. Previously only the distributed path did,
        # so dist=False wrote a full final batch and overshot n_samples.
        samples = samples[:_batch_size]
        if accelerator.is_main_process:
            for sample in samples:
                save_image(sample, os.path.join(path, f"{idx}.png"))
                idx += 1
def grad_norm(model):
    """Return the global L2 norm of `model`'s parameter gradients as a float.

    Parameters whose `.grad` is None (e.g. those frozen by
    `requires_grad_(False)`) are skipped; they previously raised
    AttributeError. Returns 0.0 when no parameter has a gradient.
    """
    total_sq = 0.
    for p in model.parameters():
        if p.grad is None:
            continue
        total_sq += p.grad.detach().norm(2).item() ** 2
    return total_sq ** (1. / 2)
from collections import defaultdict, deque
class SmoothedValue(object):
    """Running statistics over a sliding window of values plus a global average.

    `median`, `avg`, `max` and `value` look only at the last `window_size`
    values; `global_avg` uses every update (weighted by `n`).
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record `value` once in the window, counting it `n` times globally."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        # Inside a method, `max` resolves to the builtin, not this property.
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
class MetricLogger(object):
    """Named collection of meters (SmoothedValue by default), exposed as attributes.

    `update(name=value, ...)` feeds meters, creating them on first use;
    `str()` joins "name: meter" entries with `delimiter`.
    """

    def __init__(self, delimiter=" "):
        self.delimiter = delimiter
        self.meters = defaultdict(SmoothedValue)

    def update(self, **kwargs):
        """Push each keyword's value (tensors are converted with `.item()`) into its meter."""
        for name, val in kwargs.items():
            val = val.item() if isinstance(val, torch.Tensor) else val
            assert isinstance(val, (float, int))
            self.meters[name].update(val)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.
        meters = self.meters
        if attr in meters:
            return meters[attr]
        instance_vars = self.__dict__
        if attr in instance_vars:
            return instance_vars[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        return self.delimiter.join(
            f"{name}: {str(meter)}" for name, meter in self.meters.items()
        )

    def add_meter(self, name, meter):
        """Register `meter` under `name`, replacing any existing one."""
        self.meters[name] = meter
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total `norm_type`-norm of the gradients of `parameters`.

    Args:
        parameters: a tensor or an iterable of tensors. Entries without a
            gradient are ignored.
        norm_type: p of the p-norm; `inf` gives the max absolute gradient.

    Returns:
        A 0-d tensor on the device of the first gradient, or `tensor(0.)`
        when no parameter has a gradient.
    """
    # `torch._six` (the previous source of `inf`) no longer exists in
    # PyTorch 2.x, so that import failed on every call.
    from math import inf

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        per_param = [p.grad.detach().abs().max().to(device) for p in parameters]
        return torch.stack(per_param).max()
    per_param = [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
    return torch.norm(torch.stack(per_param), norm_type)