import pickle
import torch
import torch.nn as nn
import numpy as np
import os
from tqdm import tqdm
from torchvision.utils import save_image
from torch import distributed as dist
from loguru import logger
logging = logger
def set_logger(log_level='info', fname=None):
    # Configure the loguru logger aliased above: log to stderr at the requested
    # level and, if a file name is given, also append to that file.
    import sys
    logger.remove()
    logger.add(sys.stderr, level=log_level.upper())
    if fname is not None:
        logger.add(fname, level=log_level.upper())
def dct2str(dct):
    return str({k: f'{v:.6g}' for k, v in dct.items()})
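# Factory for the backbone network: `name` selects one of the UViT variants in
# libs/ and the remaining keyword arguments are forwarded to its constructor.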
def get_nnet(name, **kwargs):
    if name == 'uvit_t2i_vq':
        from libs.uvit_t2i_vq import UViT
        return UViT(**kwargs)
    elif name == 'uvit_vq':
        from libs.uvit_vq import UViT
        return UViT(**kwargs)
    else:
        raise NotImplementedError(name)
def set_seed(seed: int):
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
def get_optimizer(params, name, **kwargs):
    if name == 'adam':
        from torch.optim import Adam
        return Adam(params, **kwargs)
    elif name == 'adamw':
        from torch.optim import AdamW
        return AdamW(params, **kwargs)
    else:
        raise NotImplementedError(name)
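# Linear warmup schedule: the learning-rate multiplier ramps from 0 to 1 over
# `warmup_steps` optimizer steps and then stays at 1; a non-positive
# `warmup_steps` disables warmup entirely.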
def customized_lr_scheduler(optimizer, warmup_steps=-1):
    from torch.optim.lr_scheduler import LambdaLR

    def fn(step):
        if warmup_steps > 0:
            return min(step / warmup_steps, 1)
        else:
            return 1

    return LambdaLR(optimizer, fn)
def get_lr_scheduler(optimizer, name, **kwargs):
    if name == 'customized':
        return customized_lr_scheduler(optimizer, **kwargs)
    else:
        raise NotImplementedError(name)
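# Exponential moving average of parameters: p_dest <- rate * p_dest + (1 - rate) * p_src.
# Adapter parameters are excluded from the average and copied over directly, so the
# EMA model always carries the current adapter weights.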
def ema(model_dest: nn.Module, model_src: nn.Module, rate):
    param_dict_src = dict(model_src.named_parameters())
    for p_name, p_dest in model_dest.named_parameters():
        p_src = param_dict_src[p_name]
        assert p_src is not p_dest
        if 'adapter' not in p_name:
            p_dest.data.mul_(rate).add_(p_src.data, alpha=1 - rate)
        else:
            p_dest.data = p_src.detach().clone()
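# TrainState bundles everything that has to be checkpointed together: the
# optimizer, the LR scheduler, the global step, the network and its EMA copy.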
class TrainState(object):
    def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.step = step
        self.nnet = nnet
        self.nnet_ema = nnet_ema

    def ema_update(self, rate=0.9999):
        if self.nnet_ema is not None:
            ema(self.nnet_ema, self.nnet, rate)

    def save(self, path, adapter_only=False, name=''):
        os.makedirs(path, exist_ok=True)
        torch.save(self.step, os.path.join(path, 'step.pth'))
        if adapter_only:
            torch.save(self.nnet.adapter.state_dict(), os.path.join(path, name + 'adapter.pth'))
        else:
            for key, val in self.__dict__.items():
                if key != 'step' and val is not None:
                    torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))
    def make_dict(self, model, state_dict):
        # Merge a loaded state_dict with the model's current one: keys missing
        # from the checkpoint fall back to the model's existing weights.
        model_state = model.state_dict()
        state = {}
        for k in model_state.keys():
            if k in state_dict:
                state[k] = state_dict[k].clone()
            else:
                state[k] = model_state[k].clone()
        return state
    def load(self, path):
        logging.info(f'load from {path}')
        self.step = torch.load(os.path.join(path, 'step.pth'), map_location='cpu')
        for key, val in self.__dict__.items():
            if key not in ('step', 'optimizer', 'lr_scheduler') and val is not None:
                ckpt = torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu')
                if key in ('nnet', 'nnet_ema'):
                    val.load_state_dict(self.make_dict(val, ckpt))
                else:
                    val.load_state_dict(ckpt)
    def load_adapter(self, path):
        logging.info('load adapter from {}'.format(path))
        adapter = torch.load(path, map_location='cpu')
        keys = ['nnet', 'nnet_ema']
        for key in keys:
            if self.__dict__.get(key) is not None:
                self.__dict__[key].adapter.load_state_dict(adapter)
            else:
                logging.info(f'{key} is not set, skip loading adapter weights')
    def resume(self, ckpt_root, adapter_path=None, step=None):
        if not os.path.exists(ckpt_root):
            return
        if ckpt_root.endswith('.ckpt'):
            ckpt_path = ckpt_root
        else:
            if step is None:
                ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
                if not ckpts:
                    return
                steps = map(lambda x: int(x.split(".")[0]), ckpts)
                step = max(steps)
            ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
        logging.info(f'resume from {ckpt_path}')
        self.load(ckpt_path)
        if adapter_path is not None:
            self.load_adapter(adapter_path)
    def to(self, device):
        for key, val in self.__dict__.items():
            if isinstance(val, nn.Module):
                val.to(device)

    def freeze(self):
        # Freeze the full network, then re-enable gradients for adapter parameters only.
        self.nnet.requires_grad_(False)
        for name, p in self.nnet.named_parameters():
            if 'adapter' in name:
                p.requires_grad_(True)
def cnt_params(model):
    return sum(param.numel() for param in model.parameters())
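# Build the training state from the config: two copies of the network are created
# (training copy and EMA copy), but only the adapter parameters are handed to the
# optimizer; ema_update(0) then initializes the EMA weights as an exact copy.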
def initialize_train_state(config, device):
    params = []
    nnet = get_nnet(**config.nnet)
    params += nnet.adapter.parameters()
    nnet_ema = get_nnet(**config.nnet)
    nnet_ema.eval()
    logging.info(f'nnet has {cnt_params(nnet)} parameters')
    optimizer = get_optimizer(params, **config.optimizer)
    lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
    train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
                             nnet=nnet, nnet_ema=nnet_ema)
    train_state.ema_update(0)
    train_state.to(device)
    return train_state
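# Split `n_samples` into mini-batch sizes, e.g. amortize(10, 4) -> [4, 4, 2].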
def amortize(n_samples, batch_size):
    k = n_samples // batch_size
    r = n_samples % batch_size
    return k * [batch_size] if r == 0 else k * [batch_size] + [r]
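# Generate `n_samples` images and write them to `path` as PNG files. Each process
# draws `mini_batch_size` samples per step; with `dist=True` the per-process batches
# are gathered across processes and only the main process writes to disk.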
def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, dist=True):
    if path:
        os.makedirs(path, exist_ok=True)
    idx = 0
    batch_size = mini_batch_size * accelerator.num_processes if dist else mini_batch_size
    for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
        samples = sample_fn(mini_batch_size)
        if unpreprocess_fn is not None:
            samples = unpreprocess_fn(samples)
        if dist:
            samples = accelerator.gather(samples.contiguous())[:_batch_size]
        if accelerator.is_main_process:
            for sample in samples:
                save_image(sample, os.path.join(path, f"{idx}.png"))
                idx += 1
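# Overall L2 norm of the gradients currently stored on `model`, useful for logging;
# parameters without gradients (e.g. frozen weights) are skipped.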
def grad_norm(model):
    total_norm = 0.
    for p in model.parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm
from collections import defaultdict, deque


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
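# MetricLogger keeps one SmoothedValue meter per metric name; update() feeds new
# values and meters are also reachable as attributes (e.g. metric_logger.loss).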
class MetricLogger(object):
    def __init__(self, delimiter=" "):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def add_meter(self, name, meter):
        self.meters[name] = meter
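# Total gradient norm over an iterable of parameters (same convention as
# torch.nn.utils.clip_grad_norm_): the p-norm of the per-parameter gradient norms,
# with norm_type=inf giving the maximum absolute gradient entry.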
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    from math import inf
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm
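

if __name__ == '__main__':
    # Minimal usage example for the metric helpers (illustrative only): track a
    # couple of scalar metrics and print the smoothed summary string.
    set_seed(0)
    metric_logger = MetricLogger(delimiter='  ')
    for step in range(100):
        metric_logger.update(loss=float(np.random.rand()), lr=1e-4)
    print(metric_logger)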