# AvatarArtist/DiT_VAE/vae/utils/common_utils.py
# 刘虹雨
# update
# 8ed2f16
import os
import importlib
import numpy as np
from inspect import isfunction
import torch
def shape_to_str(x):
    """Render the shape of ``x`` as a compact string such as ``"2x3x4"``.

    Args:
        x: Any object exposing an iterable ``shape`` attribute.

    Returns:
        The dimension sizes joined by the letter ``x``; an empty string when
        the shape has no dimensions.
    """
    return "x".join(map(str, x.shape))
def str2bool(v):
    """Interpret a command-line style truthy/falsy token as a ``bool``.

    Args:
        v: A ``bool`` (returned unchanged) or a value whose string form is
            one of the accepted tokens. Matching ignores case and
            surrounding whitespace.

    Returns:
        ``True`` for ``yes/true/t/y/1`` and ``False`` for ``no/false/f/n/0``.

    Raises:
        ValueError: If the value is not one of the recognised tokens.
    """
    if isinstance(v, bool):
        return v
    # Normalise through ``str`` so non-string inputs (e.g. the integers 0/1
    # or None) are handled uniformly instead of failing on a missing
    # ``.lower`` attribute.
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError('Boolean value expected.')
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Args:
        string: Dotted path; everything before the last ``.`` is imported as
            a module and the final component is looked up on it.
        reload: When true, the module is re-executed with
            ``importlib.reload`` before the attribute lookup.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: If ``string`` has no module part (no ``.`` separator).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module does not define the attribute.
    """
    module_name, sep, attr_name = string.rpartition(".")
    if not sep or not module_name:
        raise ValueError(
            f"Expected a dotted path of the form 'module.attr', got {string!r}."
        )
    module_imp = importlib.import_module(module_name)
    if reload:
        # ``reload`` returns the module object now registered in
        # ``sys.modules``, so a second import call is unnecessary.
        module_imp = importlib.reload(module_imp)
    return getattr(module_imp, attr_name)
def instantiate_from_config(config):
    """Build the object described by a ``target``/``params`` config.

    Args:
        config: Mapping with a ``target`` dotted path and optional ``params``
            keyword arguments, or one of the sentinel strings
            ``'__is_first_stage__'`` / ``"__is_unconditional__"``.

    Returns:
        ``None`` for either sentinel; otherwise the result of calling the
        object resolved from ``target`` with ``params`` as keyword arguments.

    Raises:
        KeyError: If ``config`` has no ``target`` entry and is not a sentinel.
    """
    if "target" not in config:
        if config in ('__is_first_stage__', "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    # ``params`` may be absent or explicitly empty (None); both mean
    # "no keyword arguments" rather than a ``**None`` TypeError.
    params = config.get("params") or {}
    return get_obj_from_str(config["target"])(**params)
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move one dimension of ``x`` to a new position, keeping the rest in order.

    Example: ``shift_dim(x, 1, -1)`` turns ``(b, c, t, h, w)`` into
    ``(b, t, h, w, c)``.

    Args:
        x: Tensor to permute.
        src_dim: Index of the dimension to move; negative values count from
            the end.
        dest_dim: Index the moved dimension occupies in the result; negative
            values count from the end.
        make_contiguous: When true, ``.contiguous()`` is called on the
            permuted tensor before returning it.

    Returns:
        The permuted tensor.
    """
    rank = len(x.shape)
    # Map negative indices onto their non-negative equivalents.
    src_dim = src_dim + rank if src_dim < 0 else src_dim
    dest_dim = dest_dim + rank if dest_dim < 0 else dest_dim
    assert 0 <= src_dim < rank and 0 <= dest_dim < rank

    # Every other axis keeps its relative order; the source axis is slotted
    # in at the destination position.
    order = [axis for axis in range(rank) if axis != src_dim]
    order.insert(dest_dim, src_dim)

    shifted = x.permute(order)
    return shifted.contiguous() if make_contiguous else shifted
def torch_to_np(x):
    """Convert a tensor to a ``uint8`` NumPy array with dim 1 moved last.

    Values are mapped with ``(v + 1) * 127.5``, clamped to ``[0, 255]`` and
    cast to ``uint8`` (presumably the input is an image/video batch in
    ``[-1, 1]`` — confirm against callers).

    Args:
        x: A 5-D tensor, permuted as ``(0, 2, 3, 4, 1)``, or otherwise a
            tensor that accepts the 4-D permutation ``(0, 2, 3, 1)``.

    Returns:
        A contiguous ``numpy.ndarray`` of dtype ``uint8``.
    """
    pixels = ((x.detach().cpu() + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    # Only the 5-D case gets its own axis order; everything else is treated
    # as 4-D, exactly as a plain if/else on ``dim() == 5`` would.
    channels_last = (0, 2, 3, 4, 1) if pixels.dim() == 5 else (0, 2, 3, 1)
    return pixels.permute(channels_last).contiguous().numpy()
def np_to_torch_video(x):
    """Convert a ``[t, h, w, c]`` array to a float ``[c, t, h, w]`` tensor.

    Values are rescaled with ``(v / 255 - 0.5) * 2``, so 0 maps to -1 and
    255 maps to 1.

    Args:
        x: 4-D array-like accepted by ``torch.tensor``.

    Returns:
        A float tensor with the last input axis moved to the front.
    """
    frames = torch.tensor(x).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
    unit_range = frames / 255
    return (unit_range - 0.5) * 2
def load_npz_from_dir(data_dir):
    """Load ``arr_0`` from every file in a directory and concatenate them.

    Files are visited in sorted name order so the result is deterministic
    (``os.listdir`` gives no ordering guarantee), and each archive is closed
    as soon as its array has been read.

    Args:
        data_dir: Directory whose entries are all ``.npz`` archives
            containing an ``arr_0`` array.

    Returns:
        The arrays concatenated along axis 0.

    Raises:
        ValueError: If the directory is empty (nothing to concatenate).
    """
    arrays = []
    for data_name in sorted(os.listdir(data_dir)):
        # The context manager releases the underlying file handle; indexing
        # the archive reads the array into memory first.
        with np.load(os.path.join(data_dir, data_name)) as archive:
            arrays.append(archive['arr_0'])
    return np.concatenate(arrays, axis=0)
def load_npz_from_paths(data_paths):
    """Load ``arr_0`` from each given archive and concatenate the arrays.

    Each archive is closed as soon as its array has been read, so no file
    handles are left open.

    Args:
        data_paths: Iterable of paths to ``.npz`` archives containing an
            ``arr_0`` array. The given order is preserved.

    Returns:
        The arrays concatenated along axis 0.

    Raises:
        ValueError: If ``data_paths`` is empty (nothing to concatenate).
    """
    arrays = []
    for data_path in data_paths:
        # Indexing reads the array into memory before the archive is closed.
        with np.load(data_path) as archive:
            arrays.append(archive['arr_0'])
    return np.concatenate(arrays, axis=0)
def ismap(x):
    """Return whether ``x`` is a 4-D tensor with more than 3 entries on dim 1.

    Non-tensor inputs yield ``False``. Dim 1 is presumably the channel axis
    of a ``(b, c, h, w)`` batch — confirm against callers.
    """
    return isinstance(x, torch.Tensor) and x.dim() == 4 and x.shape[1] > 3
def isimage(x):
    """Return whether ``x`` is a 4-D tensor whose dim 1 has size 1 or 3.

    Non-tensor inputs yield ``False``. Dim 1 is presumably the channel axis
    (grayscale or RGB) of a ``(b, c, h, w)`` batch — confirm against callers.
    """
    return isinstance(x, torch.Tensor) and x.dim() == 4 and x.shape[1] in (1, 3)
def exists(x):
    """Return ``True`` when ``x`` is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    return x is not None
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    Args:
        val: Preferred value; used whenever it is not ``None``.
        d: Fallback. If it is a plain Python function (as judged by
            ``inspect.isfunction``) it is called with no arguments and its
            result is returned; any other object is returned as-is.

    Returns:
        ``val``, ``d()`` or ``d`` according to the rules above.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def mean_flat(tensor):
    """Average ``tensor`` over every dimension except the first.

    Reference implementation:
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86

    Returns:
        A tensor with one mean per entry along dimension 0.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
def count_params(model, verbose=False):
    """Count the elements of all parameters yielded by ``model.parameters()``.

    Args:
        model: Object exposing a ``parameters()`` iterable whose items have
            a ``numel()`` method (e.g. a ``torch.nn.Module``).
        verbose: When true, print the model's class name together with the
            count expressed in millions.

    Returns:
        The total number of parameter elements.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
def check_istarget(name, para_list):
    """Tell whether ``name`` contains any of the given partial names.

    Args:
        name: Full name of the source parameter.
        para_list: Partial names of the target parameters.

    Returns:
        ``True`` if at least one entry of ``para_list`` occurs in ``name``,
        ``False`` otherwise (including for an empty ``para_list``).
    """
    return any(para in name for para in para_list)